Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a01ffe1

Browse files
committed
add gradient
1 parent 21e04b9 commit a01ffe1

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function update_model!(model_file_path, model)
1515
@warn "model updated!"
1616
end
1717

18-
function train(; loss_bounds=[0.05])
18+
function train(; loss_bounds=[])
1919
if has_cuda()
2020
@info "CUDA is on"
2121
device = gpu
@@ -25,15 +25,29 @@ function train(; loss_bounds=[0.05])
2525
end
2626

2727
m = Chain(
28-
FourierOperator(1=>64, (6, 64, ), relu),
29-
FourierOperator(64=>64, (6, 64, ), relu),
30-
FourierOperator(64=>64, (6, 64, ), relu),
31-
FourierOperator(64=>1, (6, 64, )),
28+
Dense(1, 64, gelu),
29+
FourierOperator(64=>64, (12, ), gelu),
30+
FourierOperator(64=>64, (12, ), gelu),
31+
FourierOperator(64=>64, (12, ), gelu),
32+
FourierOperator(64=>64, (12, ), gelu),
33+
FourierOperator(64=>64, (12, ), gelu),
34+
FourierOperator(64=>64, (12, ), gelu),
35+
FourierOperator(64=>64, (12, ), gelu),
36+
FourierOperator(64=>64, (12, ), gelu),
37+
FourierOperator(64=>64, (12, ), gelu),
38+
FourierOperator(64=>64, (12, ), gelu),
39+
FourierOperator(64=>64, (12, ), gelu),
40+
FourierOperator(64=>64, (12, ), gelu),
41+
FourierOperator(64=>64, (12, ), gelu),
42+
FourierOperator(64=>64, (12, ), gelu),
43+
FourierOperator(64=>64, (12, ), gelu),
44+
FourierOperator(64=>64, (12, )),
45+
Dense(64, 1)
3246
) |> device
3347

3448
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
3549

36-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-2))
50+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
3751

3852
loader_train, loader_test = get_dataloader()
3953

example/DoublePendulum/src/data.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,18 @@ function get_data(; i=0, n=-1)
3333
)
3434
data = (n < 0) ? collect(Matrix(df)') : collect(Matrix(df)')[:, 1:n]
3535

36-
data /= maximum(data)
37-
3836
return Float32.(data)
3937
end
4038

41-
function get_dataloader(; i=0, n_train=15734, n_test=2048, Δn=1024, batchsize=100)
42-
x = reshape(get_data(; i=i, n=-1), :) # size==(6, 17782)
43-
𝐱 = reshape(vcat([x[i:(i+6Δn-1)] for i in 1:6:(length(x)-6(Δn-1))]...), 1, 6, 1024, :)
39+
function get_dataloader(; i=0, n_train=15733, n_test=2048, Δn=1, batchsize=100)
40+
𝐱 = get_data(i=i, n=-1) # size==(6, 17782)
41+
∇𝐱 = 𝐱[:, (1+Δn):end] - 𝐱[:, 1:(end-Δn)]
42+
𝐱 = reshape(vcat(𝐱[:, 1:(end-Δn)], ∇𝐱), 1, 12, :)
4443

45-
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-Δn)], 𝐱[:, :, :, 1+Δn:n_train]
44+
𝐱_train, 𝐲_train = 𝐱[:, :, 1:(n_train-1)], 𝐱[:, :, 2:n_train]
4645
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
4746

48-
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-Δn)], 𝐱[:, :, :, (end-n_test+1+Δn):end]
47+
𝐱_test, 𝐲_test = 𝐱[:, :, (end-n_test+1):(end-1)], 𝐱[:, :, (end-n_test+2):end]
4948
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
5049

5150
return loader_train, loader_test

0 commit comments

Comments
 (0)