Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 35f9354

Browse files
committed
use all data
1 parent 5ed1804 commit 35f9354

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

example/DoublePendulum/notebook/double_pendulum.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ The data is provided by [IBM](https://developer.ibm.com/exchanges/data/all/doubl
2929

3030
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
3131
data, _, _, _ = DoublePendulum.preprocess(
32-
DoublePendulum.get_data(i=10, n=410),
32+
DoublePendulum.get_data(i=20, n=410),
3333
ratio=1
3434
);
3535

@@ -39,9 +39,9 @@ m = DoublePendulum.get_model();
3939
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233
4040
begin
4141
n = 20
42-
42+
4343
ground_truth_data = data[1, :, 1:n]
44-
44+
4545
inferenced_data = Array{Float32}(undef, 2, 4, n)
4646
inferenced_data[:, :, 1] .= data[:, :, 1]
4747
for i in 2:n
@@ -53,18 +53,18 @@ end;
5353
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94
5454
begin
5555
c = [
56-
RGB([239, 71, 111]/255...),
57-
RGB([6, 214, 160]/255...),
56+
RGB([239, 71, 111]/255...),
57+
RGB([6, 214, 160]/255...),
5858
RGB([17, 138, 178]/255...)
5959
]
6060
xi, yi = [2, 4, 6], [1, 3, 5]
61-
61+
6262
anim = @animate for i in 1:n
6363
i_data = [0, 0, inferenced_data[:, i]...]
6464
g_data = [0, 0, ground_truth_data[:, i]...]
65-
65+
6666
scatter(
67-
legend=false, ticks=false,
67+
legend=false, ticks=false,
6868
xlim=(-1000, 1000), ylim=(-1000, 1000), size=(400, 350)
6969
)
7070
plot!(i_data[xi], i_data[yi], color=:black)

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ function train(; Δt=2)
4949
call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
5050

5151
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52-
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
52+
for e in 1:50
53+
@info "Epoch $e"
54+
@time Flux.train!(loss, params(m), data, opt, cb=call_back)
55+
(e%3 == 0) && (opt.os[2].eta /= 2)
56+
end
5357
end
5458

5559
function get_model()

example/DoublePendulum/src/data.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ function get_data(; i=0, n=-1)
3737
end
3838

3939
function preprocess(𝐱; Δt=2, ratio=0.9)
40+
# move red point to (0, 0)
4041
xs_red, ys_red = 𝐱[1, :], 𝐱[2, :]
41-
𝐱[1, :] -= xs_red; 𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red
42-
𝐱[2, :] -= ys_red; 𝐱[4, :] -= ys_red; 𝐱[6, :] -= ys_red
42+
𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red
43+
𝐱[4, :] -= ys_red; 𝐱[6, :] -= ys_red
4344

45+
# needs only green and blue points
4446
𝐱 = reshape(𝐱[3:6, 1:Δt:end], 1, 4, :)
47+
# velocity of green and blue points
4548
∇𝐱 = 𝐱[:, :, 2:end] - 𝐱[:, :, 1:(end-1)]
4649

4750
𝐱 = cat(𝐱[:, :, 1:(end-1)], ∇𝐱, dims=1)
@@ -54,7 +57,7 @@ function preprocess(𝐱; Δt=2, ratio=0.9)
5457
return 𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test
5558
end
5659

57-
function get_dataloader(; n_file=10, Δt=2, ratio=0.9, batchsize=100)
60+
function get_dataloader(; n_file=20, Δt=2, ratio=0.9, batchsize=100)
5861
𝐱_train, 𝐲_train = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
5962
𝐱_test, 𝐲_test = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
6063
for i in 0:(n_file-1)

0 commit comments

Comments
 (0)