Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 4b9faf9

Browse files
committed
train task via 2-D NFO
1 parent 35f9354 commit 4b9faf9

File tree

3 files changed

+61
-36
lines changed

3 files changed

+61
-36
lines changed

example/DoublePendulum/notebook/double_pendulum.jl

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
### A Pluto.jl notebook ###
2-
# v0.15.1
2+
# v0.16.0
33

44
using Markdown
55
using InteractiveUtils
@@ -28,26 +28,35 @@ The data is provided by [IBM](https://developer.ibm.com/exchanges/data/all/doubl
2828
"
2929

3030
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
31-
data, _, _, _ = DoublePendulum.preprocess(
32-
DoublePendulum.get_data(i=20, n=410),
31+
data_x, data_y, _, _ = DoublePendulum.preprocess(
32+
DoublePendulum.get_data(i=20),
3333
ratio=1
3434
);
3535

3636
# ╔═╡ 4d0b08a4-8a54-41fd-997f-ad54d4c984cd
3737
m = DoublePendulum.get_model();
3838

39-
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233
40-
begin
41-
n = 20
42-
43-
ground_truth_data = data[1, :, 1:n]
39+
# ╔═╡ ad6302b2-3d62-4a3f-b8bf-f69bab80c7a4
40+
ground_truth_data = cat(
41+
[data_x[:, :, :, i]
42+
for i in 1:size(data_x, 3):size(data_x)[end]]..., dims=3
43+
)[1, :, :];
4444

45-
inferenced_data = Array{Float32}(undef, 2, 4, n)
46-
inferenced_data[:, :, 1] .= data[:, :, 1]
47-
for i in 2:n
48-
inferenced_data[:, :, i:i] .= m(inferenced_data[:, :, i-1:i-1])
45+
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233
46+
begin
47+
n = 5
48+
inferenced_data = data_x[:, :, :, 1:1]
49+
for i in 1:n
50+
inferenced_data = cat(
51+
inferenced_data,
52+
m(inferenced_data[:, :, :, i:i]),
53+
dims=4
54+
)
4955
end
50-
inferenced_data = inferenced_data[1, :, :]
56+
57+
inferenced_data = cat(
58+
[inferenced_data[:, :, :, i] for i in 1:n]..., dims=3
59+
)[1, :, :]
5160
end;
5261

5362
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94
@@ -59,22 +68,27 @@ begin
5968
]
6069
xi, yi = [2, 4, 6], [1, 3, 5]
6170

62-
anim = @animate for i in 1:n
71+
anim = @animate for i in 1:size(inferenced_data)[end]
6372
i_data = [0, 0, inferenced_data[:, i]...]
6473
g_data = [0, 0, ground_truth_data[:, i]...]
6574

6675
scatter(
6776
legend=false, ticks=false,
68-
xlim=(-1000, 1000), ylim=(-1000, 1000), size=(400, 350)
77+
xlim=(-1500, 1500), ylim=(-1500, 1500), size=(400, 350)
6978
)
7079
plot!(i_data[xi], i_data[yi], color=:black)
7180
scatter!(i_data[xi], i_data[yi], color=c, markersize=8)
7281
plot!(g_data[xi], g_data[yi], color=:gray)
7382
scatter!(g_data[xi], g_data[yi], color=c, markersize=4)
74-
annotate!(-900, -900, text("t=$i", :left))
83+
84+
if i 30
85+
annotate!(-1400, -1400, text("t=$i", :left, color=:black))
86+
else
87+
annotate!(-1400, -1400, text("t=$i", :left, color=:red))
88+
end
7589
end
7690

77-
gif(anim, fps=5)
91+
gif(anim)
7892
end
7993

8094
# ╔═╡ Cell order:
@@ -84,5 +98,6 @@ end
8498
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
8599
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
86100
# ╠═4d0b08a4-8a54-41fd-997f-ad54d4c984cd
101+
# ╠═ad6302b2-3d62-4a3f-b8bf-f69bab80c7a4
87102
# ╠═794374ce-6674-481d-8a3b-04db0f32d233
88103
# ╟─9c8b3f8a-1b85-4c32-a416-ead51b244b94

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function update_model!(model_file_path, model)
1515
@warn "model updated!"
1616
end
1717

18-
function train(; Δt=2)
18+
function train(; Δt=1)
1919
if has_cuda()
2020
@info "CUDA is on"
2121
device = gpu
@@ -25,11 +25,13 @@ function train(; Δt=2)
2525
end
2626

2727
m = Chain(
28-
Dense(2, Int(4096/4)),
29-
x -> reshape(x, 1, 64, 64, :),
30-
MarkovNeuralOperator(),
31-
x -> reshape(x, Int(4096/4), 4, :),
32-
Dense(Int(4096/4), 2),
28+
Dense(2, 64),
29+
FourierOperator(64=>64, (4, 16), gelu),
30+
FourierOperator(64=>64, (4, 16), gelu),
31+
FourierOperator(64=>64, (4, 16), gelu),
32+
FourierOperator(64=>64, (4, 16)),
33+
Dense(64, 128, gelu),
34+
Dense(128, 2),
3335
) |> device
3436

3537
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
@@ -49,8 +51,8 @@ function train(; Δt=2)
4951
call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
5052

5153
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52-
for e in 1:50
53-
@info "Epoch $e"
54+
for e in 1:20
55+
@info "Epoch $e\n η: $(opt.os[2].eta)"
5456
@time Flux.train!(loss, params(m), data, opt, cb=call_back)
5557
(e%3 == 0) && (opt.os[2].eta /= 2)
5658
end

example/DoublePendulum/src/data.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function get_data(; i=0, n=-1)
3636
return Float32.(data)
3737
end
3838

39-
function preprocess(𝐱; Δt=2, ratio=0.9)
39+
function preprocess(𝐱; Δt=1, nx=30, ny=30, ratio=0.9)
4040
# move red point to (0, 0)
4141
xs_red, ys_red = 𝐱[1, :], 𝐱[2, :]
4242
𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red
@@ -46,25 +46,33 @@ function preprocess(𝐱; Δt=2, ratio=0.9)
4646
𝐱 = reshape(𝐱[3:6, 1:Δt:end], 1, 4, :)
4747
# velocity of green and blue points
4848
∇𝐱 = 𝐱[:, :, 2:end] - 𝐱[:, :, 1:(end-1)]
49-
49+
# merge info of pos and velocity
5050
𝐱 = cat(𝐱[:, :, 1:(end-1)], ∇𝐱, dims=1)
5151

52-
n_train, n_test = floor(Int, ratio*size(𝐱)[end]), floor(Int, (1-ratio)*size(𝐱)[end])
52+
# with info of first nx steps to inference next ny steps
53+
n = size(𝐱)[end] - (nx + ny) + 1
54+
𝐱s = Array{Float32}(undef, size(𝐱)[1:2]..., nx, n)
55+
𝐲s = Array{Float32}(undef, size(𝐱)[1:2]..., ny, n)
56+
for i in 1:n
57+
𝐱s[:, :, :, i] .= 𝐱[:, :, i:(i+nx-1)]
58+
𝐲s[:, :, :, i] .= 𝐱[:, :, (i+nx):(i+nx+ny-1)]
59+
end
5360

54-
𝐱_train, 𝐲_train = 𝐱[:, :, 1:(n_train-1)], 𝐱[:, :, 2:n_train]
55-
𝐱_test, 𝐲_test = 𝐱[:, :, (end-n_test+1):(end-1)], 𝐱[:, :, (end-n_test+2):end]
61+
n_train = floor(Int, ratio*n)
62+
𝐱_train, 𝐲_train = 𝐱s[:, :, :, 1:n_train], 𝐲s[:, :, :, 1:n_train]
63+
𝐱_test, 𝐲_test = 𝐱s[:, :, :, (n_train+1):end], 𝐲s[:, :, :, (n_train+1):end]
5664

5765
return 𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test
5866
end
5967

60-
function get_dataloader(; n_file=20, Δt=2, ratio=0.9, batchsize=100)
61-
𝐱_train, 𝐲_train = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
62-
𝐱_test, 𝐲_test = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
68+
function get_dataloader(; n_file=20, Δt=1, nx=30, ny=30, ratio=0.9, batchsize=100)
69+
𝐱_train, 𝐲_train = Array{Float32}(undef, 2, 4, nx, 0), Array{Float32}(undef, 2, 4, ny, 0)
70+
𝐱_test, 𝐲_test = Array{Float32}(undef, 2, 4, nx, 0), Array{Float32}(undef, 2, 4, ny, 0)
6371
for i in 0:(n_file-1)
64-
𝐱_train_i, 𝐲_train_i, 𝐱_test_i, 𝐲_test_i = preprocess(get_data(i=i), Δt=Δt, ratio=ratio)
72+
𝐱_train_i, 𝐲_train_i, 𝐱_test_i, 𝐲_test_i = preprocess(get_data(i=i), Δt=Δt, nx=nx, ny=ny, ratio=ratio)
6573

66-
𝐱_train, 𝐲_train = cat(𝐱_train, 𝐱_train_i, dims=3), cat(𝐲_train, 𝐲_train_i, dims=3)
67-
𝐱_test, 𝐲_test = cat(𝐱_test, 𝐱_test_i, dims=3), cat(𝐲_test, 𝐲_test_i, dims=3)
74+
𝐱_train, 𝐲_train = cat(𝐱_train, 𝐱_train_i, dims=4), cat(𝐲_train, 𝐲_train_i, dims=4)
75+
𝐱_test, 𝐲_test = cat(𝐱_test, 𝐱_test_i, dims=4), cat(𝐲_test, 𝐲_test_i, dims=4)
6876
end
6977

7078
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)

0 commit comments

Comments
 (0)