This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit da12420

Move GNO
1 parent 54e2dec commit da12420

2 files changed: +46 -3 lines changed


example/FlowOverCircle/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
+GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
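These two dependencies back the graph-based model added below. A minimal sketch of how the entries could be reproduced locally, assuming the example folder is used as the active project (Pkg resolves the UUIDs itself):

julia> using Pkg
julia> Pkg.activate("example/FlowOverCircle")
julia> Pkg.add(["GeometricFlux", "Graphs"])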

example/FlowOverCircle/src/FlowOverCircle.jl

Lines changed: 44 additions & 3 deletions
@@ -1,7 +1,7 @@
 module FlowOverCircle
 
 using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
-using NeuralOperators, Flux
+using NeuralOperators, Flux, GeometricFlux, Graphs
 using CUDA, FluxTraining, BSON
 
 function circle(n, m; Re=250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
@@ -31,9 +31,16 @@ function gen_data(ts::AbstractRange)
     return 𝐩s
 end
 
-function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
+function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100, flatten=false)
     data = gen_data(ts)
-    data_train, data_test = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
+    𝐱, 𝐲 = data[:, :, :, 1:end-1], data[:, :, :, 2:end]
+    n = length(ts) - 1
+
+    if flatten
+        𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n)
+    end
+
+    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at=ratio)
 
     loader_train = DataLoader(data_train, batchsize=batchsize, shuffle=true)
     loader_test = DataLoader(data_test, batchsize=batchsize, shuffle=false)
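The new flatten keyword collapses each 2-D pressure snapshot into per-node features for the graph model added in the next hunk. A minimal shape sketch, assuming one channel on the 96×64 grid used by train_gno (the exact layout of gen_data's output is not shown in this diff):

# hypothetical array standing in for data[:, :, :, 1:end-1]: 1 channel, 96×64 grid, 8 snapshots
x = rand(Float32, 1, 96, 64, 8)
x_flat = reshape(x, 1, :, 8)   # size(x_flat) == (1, 6144, 8)
# each snapshot becomes a 1×6144 feature slice: one scalar feature per node of grid([96, 64])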
@@ -66,6 +73,40 @@ function train(; epochs=50)
     return learner
 end
 
+function train_gno(; epochs=50)
+    if has_cuda()
+        @info "CUDA is on"
+        device = gpu
+        CUDA.allowscalar(false)
+    else
+        device = cpu
+    end
+
+    featured_graph = FeaturedGraph(grid([96, 64]))
+
+    model = Chain(
+        Dense(1, 16),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        Dense(16, 1),
+    )
+    data = get_dataloader(batchsize=16, flatten=true)
+    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    loss_func = l₂loss
+
+    learner = Learner(
+        model, data, optimiser, loss_func,
+        ToDevice(device, device),
+        Checkpointer(joinpath(@__DIR__, "../model/"))
+    )
+
+    fit!(learner, epochs)
+
+    return learner
+end
+
 function get_model()
     model_path = joinpath(@__DIR__, "../model/")
     model_file = readdir(model_path)[end]
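A hedged usage sketch for the new entry point, assuming the example environment is instantiated; epochs is the only keyword the diff defines:

julia> using FlowOverCircle

julia> learner = FlowOverCircle.train_gno(epochs=50)

Checkpoints are written to the model/ directory through the Checkpointer callback, which is what lets the get_model() context above pick up the most recent saved file via readdir(model_path)[end].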
