Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit fd8ecfb

Browse files
committed
Revise training process of DoublePendulum
1 parent 286e0bc commit fd8ecfb

File tree

5 files changed

+103
-133
lines changed

5 files changed

+103
-133
lines changed

example/DoublePendulum/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ name = "DoublePendulum"
22
uuid = "0c23c1c1-5f41-4617-a685-ac46aae913c3"
33

44
[deps]
5+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
56
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
67
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
78
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
89
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
910
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
11+
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
1012
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
13+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1114
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
1215
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1316
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,87 @@
11
module DoublePendulum
22

3-
using NeuralOperators
4-
using Flux
5-
using CUDA
6-
using JLD2
3+
using DataDeps, CSV, DataFrames, MLUtils
4+
using NeuralOperators, Flux
5+
using CUDA, FluxTraining, BSON
76

8-
include("data.jl")
7+
function register_double_pendulum_chaotic()
8+
register(DataDep(
9+
"DoublePendulumChaotic",
10+
"""
11+
Dataset was generated on the basis of 21 individual runs of a double pendulum.
12+
Each of the recorded sequences lasted around 40s and consisted of around 17500 frames.
913
10-
__init__() = register_double_pendulum_chaotic()
14+
* `x_red`: Horizontal pixel coordinate of the red point (the central pivot to the first pendulum)
15+
* `y_red`: Vertical pixel coordinate of the red point (the central pivot to the first pendulum)
16+
* `x_green`: Horizontal pixel coordinate of the green point (the first pendulum)
17+
* `y_green`: Vertical pixel coordinate of the green point (the first pendulum)
18+
* `x_blue`: Horizontal pixel coordinate of the blue point (the second pendulum)
19+
* `y_blue`: Vertical pixel coordinate of the blue point (the second pendulum)
20+
21+
Page: https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/
22+
""",
23+
"https://dax-cdn.cdn.appdomain.cloud/dax-double-pendulum-chaotic/2.0.1/double-pendulum-chaotic.tar.gz",
24+
"4ca743b4b783094693d313ebedc2e8e53cf29821ee8b20abd99f8fb4c0866f8d",
25+
post_fetch_method=unpack
26+
))
27+
end
28+
29+
function get_data(; i=0, n=-1)
30+
data_path = joinpath(datadep"DoublePendulumChaotic", "original", "dpc_dataset_csv")
31+
df = CSV.read(
32+
joinpath(data_path, "$i.csv"),
33+
DataFrame,
34+
header=[:x_red, :y_red, :x_green, :y_green, :x_blue, :y_blue]
35+
)
36+
data = (n < 0) ? collect(Matrix(df)') : collect(Matrix(df)')[:, 1:n]
37+
38+
return Float32.(data)
39+
end
40+
41+
function preprocess(𝐱; Δt=1, nx=30, ny=30)
42+
# move red point to (0, 0)
43+
xs_red, ys_red = 𝐱[1, :], 𝐱[2, :]
44+
𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red
45+
𝐱[4, :] -= ys_red; 𝐱[6, :] -= ys_red
46+
47+
# needs only green and blue points
48+
𝐱 = reshape(𝐱[3:6, 1:Δt:end], 1, 4, :)
49+
# velocity of green and blue points
50+
∇𝐱 = 𝐱[:, :, 2:end] - 𝐱[:, :, 1:(end-1)]
51+
# merge info of pos and velocity
52+
𝐱 = cat(𝐱[:, :, 1:(end-1)], ∇𝐱, dims=1)
53+
54+
# with info of first nx steps to inference next ny steps
55+
n = size(𝐱)[end] - (nx + ny) + 1
56+
𝐱s = Array{Float32}(undef, size(𝐱)[1:2]..., nx, n)
57+
𝐲s = Array{Float32}(undef, size(𝐱)[1:2]..., ny, n)
58+
for i in 1:n
59+
𝐱s[:, :, :, i] .= 𝐱[:, :, i:(i+nx-1)]
60+
𝐲s[:, :, :, i] .= 𝐱[:, :, (i+nx):(i+nx+ny-1)]
61+
end
62+
63+
return 𝐱s, 𝐲s
64+
end
1165

12-
function update_model!(model_file_path, model)
13-
model = cpu(model)
14-
jldsave(model_file_path; model)
15-
@warn "model updated!"
66+
function get_dataloader(; n_file=20, Δt=1, nx=30, ny=30, ratio=0.9, batchsize=100)
67+
𝐱s, 𝐲s = Array{Float32}(undef, 2, 4, nx, 0), Array{Float32}(undef, 2, 4, ny, 0)
68+
for i in 0:(n_file-1)
69+
𝐱s_i, 𝐲s_i = preprocess(get_data(i=i), Δt=Δt, nx=nx, ny=ny)
70+
𝐱s, 𝐲s = cat(𝐱s, 𝐱s_i, dims=4), cat(𝐲s, 𝐲s_i, dims=4)
71+
end
72+
73+
data = shuffleobs((𝐱s, 𝐲s))
74+
data_train, data_test = splitobs(data, at=ratio)
75+
76+
loader_train = Flux.DataLoader(data_train, batchsize=batchsize, shuffle=true)
77+
loader_test = Flux.DataLoader(data_test, batchsize=batchsize, shuffle=false)
78+
79+
return loader_train, loader_test
1680
end
1781

18-
function train(; Δt=1)
82+
__init__() = register_double_pendulum_chaotic()
83+
84+
function train(; Δt=1, epochs=20)
1985
if has_cuda()
2086
@info "CUDA is on"
2187
device = gpu
@@ -24,46 +90,27 @@ function train(; Δt=1)
2490
device = cpu
2591
end
2692

27-
m = Chain(
28-
Dense(2, 64),
29-
OperatorKernel(64=>64, (4, 16), FourierTransform, gelu),
30-
OperatorKernel(64=>64, (4, 16), FourierTransform, gelu),
31-
OperatorKernel(64=>64, (4, 16), FourierTransform, gelu),
32-
OperatorKernel(64=>64, (4, 16), FourierTransform),
33-
Dense(64, 128, gelu),
34-
Dense(128, 2),
35-
) |> device
93+
model = FourierNeuralOperator(ch=(2, 64, 64, 64, 64, 64, 128, 2), modes=(4, 16), σ=gelu)
94+
data = get_dataloader(Δt=Δt)
95+
optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
96+
loss_func = l₂loss
3697

37-
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
98+
learner = Learner(
99+
model, data, optimiser, loss_func,
100+
ToDevice(device, device),
101+
Checkpointer(joinpath(@__DIR__, "../model/"))
102+
)
38103

39-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
104+
fit!(learner, epochs)
40105

41-
loader_train, loader_test = get_dataloader(Δt=Δt)
42-
43-
losses = Float32[]
44-
function validate()
45-
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
46-
@info "loss: $validation_loss"
47-
48-
push!(losses, validation_loss)
49-
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
50-
end
51-
call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
52-
53-
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
54-
for e in 1:20
55-
@info "Epoch $e\n η: $(opt.os[2].eta)"
56-
@time Flux.train!(loss, params(m), data, opt, cb=call_back)
57-
(e%3 == 0) && (opt.os[2].eta /= 2)
58-
end
106+
return learner
59107
end
60108

61109
function get_model()
62-
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
63-
model = f["model"]
64-
close(f)
110+
model_path = joinpath(@__DIR__, "../model/")
111+
model_file = readdir(model_path)[end]
65112

66-
return model
113+
return BSON.load(joinpath(model_path, model_file), @__MODULE__)[:model]
67114
end
68115

69-
end
116+
end # module

example/DoublePendulum/src/data.jl

Lines changed: 0 additions & 82 deletions
This file was deleted.

example/DoublePendulum/test/data.jl

Lines changed: 0 additions & 5 deletions
This file was deleted.

example/DoublePendulum/test/runtests.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,12 @@ using DoublePendulum
22
using Test
33

44
@testset "DoublePendulum" begin
5-
include("data.jl")
5+
xs = DoublePendulum.get_data(i=0, n=100)
6+
7+
@test size(xs) == (6, 100)
8+
9+
learner = DoublePendulum.train(epochs=5)
10+
loss = learner.cbstate.metricsepoch[ValidationPhase()][:Loss].values[end]
11+
@test loss < 0.05
12+
613
end

0 commit comments

Comments
 (0)