Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 6280598

Browse files
committed
revise model
1 parent 9d8756c commit 6280598

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,36 @@ function train()
1818
end
1919

2020
m = Chain(
21-
FourierOperator(6=>6, (16, ), gelu),
22-
FourierOperator(6=>6, (16, ), gelu),
23-
FourierOperator(6=>6, (16, ), gelu),
21+
FourierOperator(6=>6, (16, ), relu),
22+
FourierOperator(6=>6, (16, ), relu),
23+
FourierOperator(6=>6, (16, ), relu),
24+
FourierOperator(6=>6, (16, ), relu),
2425
FourierOperator(6=>6, (16, )),
2526
) |> device
2627

2728
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
2829

30+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
31+
2932
loader_train, loader_test = get_dataloader()
3033

34+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
35+
36+
loss_bounds = [0.3, 0.05, 0.01]
3137
function validate()
32-
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
33-
@info "loss: $(sum(validation_losses)/length(loader_test))"
38+
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
39+
@info "loss: $validation_loss"
40+
41+
isempty(loss_bounds) && return
42+
if validation_loss < loss_bounds[1]
43+
@warn "change η"
44+
opt.os[2].eta /= 2
45+
popfirst!(loss_bounds)
46+
end
3447
end
3548

36-
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
37-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-4))
38-
call_back = Flux.throttle(validate, 0.5, leading=false, trailing=true)
39-
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
49+
call_back = Flux.throttle(validate, 1, leading=false, trailing=true)
50+
Flux.@epochs 300 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
4051
end
4152

4253
end

0 commit comments

Comments (0)