Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a634b8b

Browse files
committed
after NS model successful
1 parent 664b657 commit a634b8b

File tree

3 files changed

+49
-56
lines changed

3 files changed

+49
-56
lines changed

example/DoublePendulum/notebook/data.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,41 @@ begin
1414
end
1515

1616
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
17-
data = DoublePendulum.get_data(i=0, n=-1)[:, end-2048+1:end]
17+
data, _, _, _ = DoublePendulum.preprocess(DoublePendulum.get_data(i=0), ratio=1);
18+
# data = reshape(DoublePendulum.get_data(; i=0, n=1000), 1, 1, 6, :);
1819

1920
# ╔═╡ 4d0b08a4-8a54-41fd-997f-ad54d4c984cd
2021
m = DoublePendulum.get_model()
2122

2223
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233
2324
begin
24-
n = 10
25+
n = 100
2526

26-
ground_truth_data = 1 .- data[:, 1+n:1024+n]
27+
ground_truth_data = data[1, 1, :, 1:n]
2728

28-
inferenced_data = m(reshape(data[:, 1:1024], 1, 6, :, 1))
29-
for i in 1:n
30-
inferenced_data = m(inferenced_data)
31-
end
32-
inferenced_data = 1 .- reshape(inferenced_data, 6, :)
29+
# inferenced_data = m(reshape(data[:, 1:1024], 1, 6, :, 1))
30+
# for i in 1:n
31+
# inferenced_data = m(inferenced_data)
32+
# end
33+
# inferenced_data = 1 .- reshape(inferenced_data, 6, :)
3334
end
3435

3536
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94
3637
begin
37-
anim = @animate for i in 1:4:1024
38-
scatter(legend=false, xlim=(0, 1), ylim=(-0.5, 1), size=(600, 500))
39-
scatter!(
40-
inferenced_data[[2, 4, 6], i], inferenced_data[[1, 3, 5], i],
41-
color=[
42-
RGB([239, 71, 111]/255...),
43-
RGB([6, 214, 160]/255...),
44-
RGB([17, 138, 178]/255...)
45-
],
46-
markersize=8
38+
anim = @animate for i in 1:n
39+
scatter(
40+
legend=false, ticks=false,
41+
xlim=(0, 2500), ylim=(0, 2500), size=(600, 500)
4742
)
43+
# scatter!(
44+
# inferenced_data[[2, 4, 6], i], inferenced_data[[1, 3, 5], i],
45+
# color=[
46+
# RGB([239, 71, 111]/255...),
47+
# RGB([6, 214, 160]/255...),
48+
# RGB([17, 138, 178]/255...)
49+
# ],
50+
# markersize=8
51+
# )
4852
scatter!(
4953
ground_truth_data[[2, 4, 6], i], ground_truth_data[[1, 3, 5], i],
5054
color=[
@@ -54,7 +58,7 @@ begin
5458
],
5559
markersize=4
5660
)
57-
annotate!(0.1, -0.4, text("i=$i", :left))
61+
annotate!(50, 50, text("t=$i", :left))
5862
end
5963

6064
gif(anim)
@@ -66,4 +70,4 @@ end
6670
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
6771
# ╠═4d0b08a4-8a54-41fd-997f-ad54d4c984cd
6872
# ╠═794374ce-6674-481d-8a3b-04db0f32d233
69-
# ╟─9c8b3f8a-1b85-4c32-a416-ead51b244b94
73+
# ╠═9c8b3f8a-1b85-4c32-a416-ead51b244b94

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function update_model!(model_file_path, model)
1515
@warn "model updated!"
1616
end
1717

18-
function train(; loss_bounds=[])
18+
function train(; Δt=2)
1919
if has_cuda()
2020
@info "CUDA is on"
2121
device = gpu
@@ -25,33 +25,18 @@ function train(; loss_bounds=[])
2525
end
2626

2727
m = Chain(
28-
Dense(1, 64, gelu),
29-
FourierOperator(64=>64, (12, ), gelu),
30-
FourierOperator(64=>64, (12, ), gelu),
31-
FourierOperator(64=>64, (12, ), gelu),
32-
FourierOperator(64=>64, (12, ), gelu),
33-
FourierOperator(64=>64, (12, ), gelu),
34-
FourierOperator(64=>64, (12, ), gelu),
35-
FourierOperator(64=>64, (12, ), gelu),
36-
FourierOperator(64=>64, (12, ), gelu),
37-
FourierOperator(64=>64, (12, ), gelu),
38-
FourierOperator(64=>64, (12, ), gelu),
39-
FourierOperator(64=>64, (12, ), gelu),
40-
FourierOperator(64=>64, (12, ), gelu),
41-
FourierOperator(64=>64, (12, ), gelu),
42-
FourierOperator(64=>64, (12, ), gelu),
43-
FourierOperator(64=>64, (12, ), gelu),
44-
FourierOperator(64=>64, (12, )),
45-
Dense(64, 1)
28+
Dense(1, 350), # (1, 2, 6, :) -> (350, 2, 6, :)
29+
x -> reshape(x, 1, 60, 70, :), # (350, 2, 6, :) -> (1, 60, 70, :)
30+
MarkovNeuralOperator(),
31+
x -> reshape(x, 350, 2, 6, :), # (1, 60, 70, :) -> (350, 2, 6, :)
32+
Dense(350, 1), # (350, 2, 6, :) -> (1, 2, 6, :)
4633
) |> device
4734

4835
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
4936

5037
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
5138

52-
loader_train, loader_test = get_dataloader()
53-
54-
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
39+
loader_train, loader_test = get_dataloader(Δt=Δt)
5540

5641
losses = Float32[]
5742
function validate()
@@ -60,16 +45,10 @@ function train(; loss_bounds=[])
6045

6146
push!(losses, validation_loss)
6247
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
63-
64-
isempty(loss_bounds) && return
65-
if validation_loss < loss_bounds[1]
66-
@warn "change η"
67-
opt.os[2].eta /= 2
68-
popfirst!(loss_bounds)
69-
end
7048
end
7149
call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
7250

51+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
7352
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
7453
end
7554

example/DoublePendulum/src/data.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,25 @@ function get_data(; i=0, n=-1)
3636
return Float32.(data)
3737
end
3838

39-
function get_dataloader(; i=0, n_train=15733, n_test=2048, Δn=1, batchsize=100)
40-
𝐱 = get_data(i=i, n=-1) # size==(6, 17782)
41-
∇𝐱 = 𝐱[:, (1+Δn):end] - 𝐱[:, 1:(end-Δn)]
42-
𝐱 = reshape(vcat(𝐱[:, 1:(end-Δn)], ∇𝐱), 1, 12, :)
39+
function preprocess(𝐱; Δt=2, ratio=0.9)
40+
𝐱 = reshape(𝐱[:, 1:Δt:end], 1, 1, 6, :)
41+
∇𝐱 = 𝐱[:, :, :, 2:end] - 𝐱[:, :, :, 1:(end-1)]
4342

44-
𝐱_train, 𝐲_train = 𝐱[:, :, 1:(n_train-1)], 𝐱[:, :, 2:n_train]
45-
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
43+
𝐱 = cat(𝐱[:, :, :, 1:(end-1)], ∇𝐱, dims=2)
44+
45+
n_train, n_test = floor(Int, ratio*size(𝐱)[end]), floor(Int, (1-ratio)*size(𝐱)[end])
4646

47-
𝐱_test, 𝐲_test = 𝐱[:, :, (end-n_test+1):(end-1)], 𝐱[:, :, (end-n_test+2):end]
47+
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-1)], 𝐱[:, :, :, 2:n_train]
48+
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-1)], 𝐱[:, :, :, (end-n_test+2):end]
49+
50+
return 𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test
51+
end
52+
53+
function get_dataloader(; i=0, Δt=2, ratio=0.9, batchsize=100)
54+
𝐱 = get_data(i=i) # size==(6, :)
55+
𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test = preprocess(𝐱, Δt=Δt, ratio=ratio) # size==(1, 2, 6, :)
56+
57+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
4858
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
4959

5060
return loader_train, loader_test

0 commit comments

Comments
 (0)