Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 97772af

Browse files
committed
revise model and save model
1 parent 1d0d7f8 commit 97772af

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

example/DoublePendulum/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
77
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
88
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
99
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
10+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1011
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
1112
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1213
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"

example/DoublePendulum/model/.gitkeep

Whitespace-only changes.

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@ module DoublePendulum
33
using NeuralOperators
44
using Flux
55
using CUDA
6+
using JLD2
67

78
include("data.jl")
89

910
__init__() = register_double_pendulum_chaotic()
1011

11-
function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
12+
function update_model!(model_file_path, model)
13+
model = cpu(model)
14+
jldsave(model_file_path; model)
15+
@warn "model updated!"
16+
end
17+
18+
function train(; loss_bounds=[1, 0.3, 0.1, 0.05])
1219
if has_cuda()
1320
@info "CUDA is on"
1421
device = gpu
@@ -21,6 +28,7 @@ function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
2128
FourierOperator(6=>64, (16, ), relu),
2229
FourierOperator(64=>64, (16, ), relu),
2330
FourierOperator(64=>64, (16, ), relu),
31+
FourierOperator(64=>64, (16, ), relu),
2432
FourierOperator(64=>6, (16, )),
2533
) |> device
2634

@@ -32,20 +40,24 @@ function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
3240

3341
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
3442

43+
losses = Float32[]
3544
function validate()
3645
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
3746
@info "loss: $validation_loss"
3847

48+
push!(losses, validation_loss)
49+
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
50+
3951
isempty(loss_bounds) && return
4052
if validation_loss < loss_bounds[1]
4153
@warn "change η"
4254
opt.os[2].eta /= 2
4355
popfirst!(loss_bounds)
4456
end
4557
end
46-
47-
call_back = Flux.throttle(validate, 1, leading=false, trailing=true)
48-
Flux.@epochs 300 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
58+
call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
59+
60+
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
4961
end
5062

5163
end

0 commit comments

Comments
 (0)