Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a2e1a37

Browse files
committed
train MNO in 2-D and make sure loss<1e-2 about 5e-3 at epoch=35
1 parent 482119a commit a2e1a37

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function update_model!(model_file_path, model)
1515
@warn "model updated!"
1616
end
1717

18-
function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
18+
function train(; loss_bounds=[0.05])
1919
if has_cuda()
2020
@info "CUDA is on"
2121
device = gpu
@@ -25,10 +25,10 @@ function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
2525
end
2626

2727
m = Chain(
28-
FourierOperator(6=>64, (16, ), relu),
29-
FourierOperator(64=>64, (16, ), relu),
30-
FourierOperator(64=>64, (16, ), relu),
31-
FourierOperator(64=>6, (16, )),
28+
FourierOperator(1=>64, (6, 64, ), relu),
29+
FourierOperator(64=>64, (6, 64, ), relu),
30+
FourierOperator(64=>64, (6, 64, ), relu),
31+
FourierOperator(64=>1, (6, 64, )),
3232
) |> device
3333

3434
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]

example/DoublePendulum/src/data.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ end
4242

4343
function get_dataloader(; i=0, n_train=15001, n_test=1501, Δn=1024, batchsize=100)
4444
x = reshape(get_double_pendulum_chaotic_data(; i=i, n=-1), :)
45-
𝐱 = reshape(vcat([x[i:(i+6Δn-1)] for i in 1:6:(length(x)-6(Δn-1))]...), 6, 1024, :)
45+
𝐱 = reshape(vcat([x[i:(i+6Δn-1)] for i in 1:6:(length(x)-6(Δn-1))]...), 1, 6, 1024, :)
4646

47-
𝐱_train, 𝐲_train = 𝐱[:, :, 1:(n_train-1)], 𝐱[:, :, 2:n_train]
47+
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-1)], 𝐱[:, :, :, 2:n_train]
4848
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
4949

50-
𝐱_test, 𝐲_test = 𝐱[:, :, (end-n_test+1):(end-1)], 𝐱[:, :, (end-n_test+2):end]
50+
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-1)], 𝐱[:, :, :, (end-n_test+2):end]
5151
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
5252

5353
return loader_train, loader_test

0 commit comments

Comments
 (0)