Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 1d0d7f8

Browse files
committed
revise model
1 parent 5043249 commit 1d0d7f8

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ include("data.jl")
88

99
__init__() = register_double_pendulum_chaotic()
1010

11-
function train()
11+
function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
1212
if has_cuda()
1313
@info "CUDA is on"
1414
device = gpu
@@ -18,22 +18,20 @@ function train()
1818
end
1919

2020
m = Chain(
21-
FourierOperator(6=>6, (16, ), relu),
22-
FourierOperator(6=>6, (16, ), relu),
23-
FourierOperator(6=>6, (16, ), relu),
24-
FourierOperator(6=>6, (16, ), relu),
25-
FourierOperator(6=>6, (16, )),
21+
FourierOperator(6=>64, (16, ), relu),
22+
FourierOperator(64=>64, (16, ), relu),
23+
FourierOperator(64=>64, (16, ), relu),
24+
FourierOperator(64=>6, (16, )),
2625
) |> device
2726

2827
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
2928

30-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
29+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-2))
3130

3231
loader_train, loader_test = get_dataloader()
3332

3433
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
3534

36-
loss_bounds = [0.3, 0.05, 0.01]
3735
function validate()
3836
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
3937
@info "loss: $validation_loss"

0 commit comments

Comments
 (0)