
Commit a0c5c07

add GNO example

fix

1 parent b64ea84 commit a0c5c07

4 files changed: +82 −10 lines

example/SuperResolution/Project.toml

Lines changed: 3 additions & 0 deletions

@@ -4,11 +4,14 @@ uuid = "a8258e1f-331c-4af2-83e9-878628278453"
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"
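The three new entries register GeometricFlux, Graphs, and ProgressMeter as dependencies of the example project. A minimal sketch of picking up the updated project before running the example (the path is relative to a repository checkout):

using Pkg
Pkg.activate("example/SuperResolution")  # activate the example's project environment
Pkg.instantiate()                        # resolve and install the newly added deps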

example/SuperResolution/src/SuperResolution.jl

Lines changed: 3 additions & 0 deletions

@@ -4,8 +4,11 @@ using NeuralOperators
 using Flux
 using Flux.Losses: mse
 using Flux.Data: DataLoader
+using GeometricFlux
+using Graphs
 using CUDA
 using JLD2
+using ProgressMeter: Progress, next!

 include("data.jl")
 include("models.jl")

example/SuperResolution/src/data.jl

Lines changed: 4 additions & 4 deletions

@@ -15,20 +15,20 @@ function circle(n, m; Re=250)
     Simulation((n+2, m+2), [U, 0.], R; ν, body)
 end

-function gen_data(ts::AbstractRange)
+function gen_data(ts::AbstractRange, T=Float32)
     n, m = 2 * 3(2^5), 2 * 2^6
     circ = circle(n, m)

-    𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
+    𝐩s = Array{T}(undef, 1, n, m, length(ts))
     for (i, t) in enumerate(ts)
         sim_step!(circ, t)
-        𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
+        𝐩s[1, :, :, i] .= T.(circ.flow.p)[2:end-1, 2:end-1]
     end

     return 𝐩s
 end

-function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
+function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Real=0.95, batchsize=100)
     data = gen_data(ts)

     n_train, n_test = floor(Int, length(ts)*ratio), floor(Int, length(ts)*(1-ratio))
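Parameterizing `gen_data` over the element type `T` lets callers request, say, `Float64` pressure fields without editing the function body, and `ratio::Real` accepts any numeric split fraction. A minimal usage sketch (the short time range is illustrative only):

𝐩s32 = gen_data(LinRange(100, 200, 10))           # default Float32, shape (1, n, m, 10)
𝐩s64 = gen_data(LinRange(100, 200, 10), Float64)  # same fields at double precision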

example/SuperResolution/src/models.jl

Lines changed: 72 additions & 6 deletions

@@ -26,7 +26,7 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
     opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η))

     # parameters
-    ps = Flux.params(model)
+    ps = Flux.params(m)

     # training
     min_loss = Inf32
@@ -37,10 +37,10 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
         progress = Progress(length(loader_train))

         for (𝐱, 𝐲) in loader_train
-            grad = gradient(() -> loss(model, 𝐱 |> device, 𝐲 |> device), ps)
+            grad = gradient(() -> loss(m, 𝐱 |> device, 𝐲 |> device), ps)
             Flux.Optimise.update!(opt, ps, grad)
-            train_loss = loss(model, loader_train, device)
-            test_loss = loss(model, loader_test, device)
+            train_loss = loss(m, loader_train, device)
+            test_loss = loss(m, loader_test, device)

             # progress meter
             next!(progress; showvalues=[
@@ -49,12 +49,78 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
             ])

             if test_loss ≤ min_loss
-                update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
+                update_model!(joinpath(@__DIR__, "../model/mno.jld2"), m)
+                min_loss = test_loss
             end

             train_steps += 1
         end
     end

     return m
-end
+end
+
+function train_gno(; channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50)
+    # GPU config
+    if cuda && CUDA.has_cuda()
+        device = gpu
+        CUDA.allowscalar(false)
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    @info "gen data... "
+    @time loader_train, loader_test = get_dataloader()
+
+    # build model
+    g = grid([12, 8])
+    fg = FeaturedGraph(g)
+
+    m = Chain(
+        Dense(1, 64),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        Dense(64, 1),
+    ) |> device
+
+    # optimizer
+    opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η))
+
+    # parameters
+    ps = Flux.params(m)
+
+    # training
+    min_loss = Inf32
+    train_steps = 0
+    @info "Start Training, total $(epochs) epochs"
+    for epoch = 1:epochs
+        @info "Epoch $(epoch)"
+        progress = Progress(length(loader_train))
+
+        for (𝐱, 𝐲) in loader_train
+            grad = gradient(() -> loss(m, 𝐱 |> device, 𝐲 |> device), ps)
+            Flux.Optimise.update!(opt, ps, grad)
+            train_loss = loss(m, loader_train, device)
+            test_loss = loss(m, loader_test, device)
+
+            # progress meter
+            next!(progress; showvalues=[
+                (:train_loss, train_loss),
+                (:test_loss, test_loss)
+            ])
+
+            if test_loss ≤ min_loss
+                update_model!(joinpath(@__DIR__, "../model/gno.jld2"), m)
+                min_loss = test_loss
+            end
+
+            train_steps += 1
+        end
+    end

+    return m
+end
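The new `train_gno` mirrors `train_mno`: same optimizer, checkpointing, and progress reporting, but the model is a chain of four graph-kernel layers over a fixed 12×8 grid graph instead of a Markov neural operator. Note the lifting and projection layers hard-code 64, so they only agree with the kernel width at the default `channel=64`. A minimal sketch of invoking it, assuming the example module is loaded in the project environment set up above (argument values are the defaults):

using SuperResolution  # load the example module

m = SuperResolution.train_gno(channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50)
# the best checkpoint by test loss is written to example/SuperResolution/model/gno.jld2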
