
Commit a18258e: Refactor
1 parent c6cdf0b

4 files changed: +25, -51 lines


example/Burgers/src/Burgers.jl (5 additions & 4 deletions)
@@ -44,18 +44,19 @@ end
 
 __init__() = register_burgers()
 
-function train(; epochs=500)
-    if has_cuda()
-        @info "CUDA is on"
+function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=500)
+    if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
+        @info "Training on GPU"
     else
         device = cpu
+        @info "Training on CPU"
     end
 
     model = FourierNeuralOperator(ch=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16, ), σ=gelu)
     data = get_dataloader()
-    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
     loss_func = l₂loss
 
     learner = Learner(
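
For reference, a minimal usage sketch of the refactored keyword API (assumptions, not part of this commit: the Burgers example module is loaded and its project environment is active; per the diff above, η₀ is the ADAM learning rate and λ the weight-decay coefficient):

using Burgers

# Defaults: GPU if one is available, η₀ = 1f-3, λ = 1f-4, 500 epochs.
learner = Burgers.train()

# Force CPU training and override the hyperparameters.
learner = Burgers.train(cuda=false, η₀=5f-4, λ=0f0, epochs=100)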

example/DoublePendulum/src/DoublePendulum.jl (5 additions & 4 deletions)
@@ -81,18 +81,19 @@ end
 
 __init__() = register_double_pendulum_chaotic()
 
-function train(; Δt=1, epochs=20)
-    if has_cuda()
-        @info "CUDA is on"
+function train(; cuda=true, Δt=1, η₀=1f-3, λ=1f-4, epochs=20)
+    if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
+        @info "Training on GPU"
     else
         device = cpu
+        @info "Training on CPU"
     end
 
     model = FourierNeuralOperator(ch=(2, 64, 64, 64, 64, 64, 128, 2), modes=(4, 16), σ=gelu)
     data = get_dataloader(Δt=Δt)
-    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
     loss_func = l₂loss
 
     learner = Learner(
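
Every train function in this commit now uses the same device-selection guard. As an illustrative sketch only (the helper name select_device is hypothetical and not part of the commit), the shared idiom reads as:

using CUDA, Flux

# Hypothetical helper showing what the repeated guard does.
function select_device(cuda::Bool)
    if cuda && CUDA.has_cuda()
        # Error on scalar indexing of GPU arrays instead of silently running slowly.
        CUDA.allowscalar(false)
        @info "Training on GPU"
        return gpu
    else
        @info "Training on CPU"
        return cpu
    end
end

device = select_device(true)  # returns Flux's gpu or cpu conversion function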

example/FlowOverCircle/src/FlowOverCircle.jl (10 additions & 9 deletions)
@@ -48,18 +48,19 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
     return loader_train, loader_test
 end
 
-function train(; epochs=50)
-    if has_cuda()
-        @info "CUDA is on"
+function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=50)
+    if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
+        @info "Training on GPU"
     else
         device = cpu
+        @info "Training on CPU"
     end
 
     model = MarkovNeuralOperator(ch=(1, 64, 64, 64, 64, 64, 1), modes=(24, 24), σ=gelu)
     data = get_dataloader()
-    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
     loss_func = l₂loss
 
     learner = Learner(
@@ -73,17 +74,17 @@ function train(; epochs=50)
     return learner
 end
 
-function train_gno(; epochs=50)
-    if has_cuda()
-        @info "CUDA is on"
+function train_gno(; cuda=true, η₀=1f-3, λ=1f-4, epochs=50)
+    if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
+        @info "Training on GPU"
     else
         device = cpu
+        @info "Training on CPU"
     end
 
     featured_graph = FeaturedGraph(grid([96, 64]))
-
     model = Chain(
         Dense(1, 16),
         WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
@@ -93,7 +94,7 @@ function train_gno(; epochs=50)
         Dense(16, 1),
     )
     data = get_dataloader(batchsize=16, flatten=true)
-    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
     loss_func = l₂loss
 
     learner = Learner(
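
The optimiser change is identical across all four examples: Flux.Optimiser chains WeightDecay(λ) and ADAM(η₀), so the decay term is added to each gradient before ADAM's adaptive step (L2 regularisation, not decoupled AdamW-style decay). A minimal sketch of the semantics, assuming the Flux 0.12/0.13-era optimiser API that this code uses:

using Flux

η₀, λ = 1f-3, 1f-4
opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))

p = randn(Float32, 4)    # a parameter array
g = randn(Float32, 4)    # its gradient
Flux.update!(opt, p, g)  # p is updated as if ADAM had seen g .+ λ .* p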

example/SuperResolution/src/SuperResolution.jl (5 additions & 34 deletions)
@@ -69,18 +69,19 @@ function fit!(learner, nepochs::Int)
     fit!(learner, nepochs, (learner.data.training, learner.data.validation, learner.data.testing))
 end
 
-function train(; epochs=50)
-    if has_cuda()
-        @info "CUDA is on"
+function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=50)
+    if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
+        @info "Training on GPU"
     else
         device = cpu
+        @info "Training on CPU"
     end
 
     model = MarkovNeuralOperator(ch=(1, 64, 64, 64, 64, 64, 1), modes=(24, 24), σ=gelu)
     data = get_dataloader()
-    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
     loss_func = l₂loss
 
     learner = Learner(
@@ -101,34 +102,4 @@ function get_model()
     return BSON.load(joinpath(model_path, model_file), @__MODULE__)[:model]
 end
 
-# using NeuralOperators
-# using Flux
-# using Flux.Losses: mse
-# using Flux.Data: DataLoader
-# using GeometricFlux
-# using Graphs
-# using CUDA
-# using JLD2
-# using ProgressMeter: Progress, next!
-
-# include("data.jl")
-# include("models.jl")
-
-# function update_model!(model_file_path, model)
-#     model = cpu(model)
-#     jldsave(model_file_path; model)
-#     @info "model updated!"
-# end
-
-# function get_model()
-#     f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
-#     model = f["model"]
-#     close(f)
-
-#     return model
-# end
-
-# loss(m, 𝐱, 𝐲) = mse(m(𝐱), 𝐲)
-# loss(m, loader::DataLoader, device) = sum(loss(m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader)/length(loader)
-
 end # module
