Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 5ed1804

Browse files
committed
revise data
1 parent a634b8b commit 5ed1804

File tree

4 files changed

+110
-85
lines changed

4 files changed

+110
-85
lines changed

example/DoublePendulum/notebook/data.jl

Lines changed: 0 additions & 73 deletions
This file was deleted.
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
### A Pluto.jl notebook ###
2+
# v0.15.1
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024
8+
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")
9+
10+
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
11+
begin
12+
using DoublePendulum
13+
using Plots
14+
end
15+
16+
# ╔═╡ 396b5d7a-a7a4-4f22-a87e-39b405e8d62a
17+
md"
18+
# Double Pendulum
19+
20+
JingYu Ning
21+
"
22+
23+
# ╔═╡ 2a606ecf-acf0-41ad-9290-7569dbb22b5a
24+
md"
25+
The data is provided by [IBM](https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/)
26+
27+
> In this dataset, videos of the double pendulum were taken using a high-speed Phantom Miro EX2 camera. To make the extraction of the arm positions easier, a matte black background was used, and the three datums were marked with red, green and blue fiducial markers. The camera was placed at 2 meters from the pendulum, with the axis of the objective aligned with the first pendulum datum. The pendulum was launched by hand, and the camera was motion triggered. The dataset was generated on the basis of 21 individual runs of the pendulum. Each of the recorded sequences lasted around 40s and consisted of around 17500 frames.
28+
"
29+
30+
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
31+
data, _, _, _ = DoublePendulum.preprocess(
32+
DoublePendulum.get_data(i=10, n=410),
33+
ratio=1
34+
);
35+
36+
# ╔═╡ 4d0b08a4-8a54-41fd-997f-ad54d4c984cd
37+
m = DoublePendulum.get_model();
38+
39+
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233
40+
begin
41+
n = 20
42+
43+
ground_truth_data = data[1, :, 1:n]
44+
45+
inferenced_data = Array{Float32}(undef, 2, 4, n)
46+
inferenced_data[:, :, 1] .= data[:, :, 1]
47+
for i in 2:n
48+
inferenced_data[:, :, i:i] .= m(inferenced_data[:, :, i-1:i-1])
49+
end
50+
inferenced_data = inferenced_data[1, :, :]
51+
end;
52+
53+
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94
54+
begin
55+
c = [
56+
RGB([239, 71, 111]/255...),
57+
RGB([6, 214, 160]/255...),
58+
RGB([17, 138, 178]/255...)
59+
]
60+
xi, yi = [2, 4, 6], [1, 3, 5]
61+
62+
anim = @animate for i in 1:n
63+
i_data = [0, 0, inferenced_data[:, i]...]
64+
g_data = [0, 0, ground_truth_data[:, i]...]
65+
66+
scatter(
67+
legend=false, ticks=false,
68+
xlim=(-1000, 1000), ylim=(-1000, 1000), size=(400, 350)
69+
)
70+
plot!(i_data[xi], i_data[yi], color=:black)
71+
scatter!(i_data[xi], i_data[yi], color=c, markersize=8)
72+
plot!(g_data[xi], g_data[yi], color=:gray)
73+
scatter!(g_data[xi], g_data[yi], color=c, markersize=4)
74+
annotate!(-900, -900, text("t=$i", :left))
75+
end
76+
77+
gif(anim, fps=5)
78+
end
79+
80+
# ╔═╡ Cell order:
81+
# ╟─396b5d7a-a7a4-4f22-a87e-39b405e8d62a
82+
# ╟─2a606ecf-acf0-41ad-9290-7569dbb22b5a
83+
# ╟─194baef2-0417-11ec-05ab-4527ef614024
84+
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
85+
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
86+
# ╠═4d0b08a4-8a54-41fd-997f-ad54d4c984cd
87+
# ╠═794374ce-6674-481d-8a3b-04db0f32d233
88+
# ╟─9c8b3f8a-1b85-4c32-a416-ead51b244b94

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ function train(; Δt=2)
2525
end
2626

2727
m = Chain(
28-
Dense(1, 350), # (1, 2, 6, :) -> (350, 2, 6, :)
29-
x -> reshape(x, 1, 60, 70, :), # (350, 2, 6, :) -> (1, 60, 70, :)
28+
Dense(2, Int(4096/4)),
29+
x -> reshape(x, 1, 64, 64, :),
3030
MarkovNeuralOperator(),
31-
x -> reshape(x, 350, 2, 6, :), # (1, 60, 70, :) -> (350, 2, 6, :)
32-
Dense(350, 1), # (350, 2, 6, :) -> (1, 2, 6, :)
31+
x -> reshape(x, Int(4096/4), 4, :),
32+
Dense(Int(4096/4), 2),
3333
) |> device
3434

3535
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]

example/DoublePendulum/src/data.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,32 @@ function get_data(; i=0, n=-1)
3737
end
3838

3939
function preprocess(𝐱; Δt=2, ratio=0.9)
40-
𝐱 = reshape(𝐱[:, 1:Δt:end], 1, 1, 6, :)
41-
∇𝐱 = 𝐱[:, :, :, 2:end] - 𝐱[:, :, :, 1:(end-1)]
40+
xs_red, ys_red = 𝐱[1, :], 𝐱[2, :]
41+
𝐱[1, :] -= xs_red; 𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red
42+
𝐱[2, :] -= ys_red; 𝐱[4, :] -= ys_red; 𝐱[6, :] -= ys_red
4243

43-
𝐱 = cat(𝐱[:, :, :, 1:(end-1)], ∇𝐱, dims=2)
44+
𝐱 = reshape(𝐱[3:6, 1:Δt:end], 1, 4, :)
45+
∇𝐱 = 𝐱[:, :, 2:end] - 𝐱[:, :, 1:(end-1)]
46+
47+
𝐱 = cat(𝐱[:, :, 1:(end-1)], ∇𝐱, dims=1)
4448

4549
n_train, n_test = floor(Int, ratio*size(𝐱)[end]), floor(Int, (1-ratio)*size(𝐱)[end])
4650

47-
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-1)], 𝐱[:, :, :, 2:n_train]
48-
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-1)], 𝐱[:, :, :, (end-n_test+2):end]
51+
𝐱_train, 𝐲_train = 𝐱[:, :, 1:(n_train-1)], 𝐱[:, :, 2:n_train]
52+
𝐱_test, 𝐲_test = 𝐱[:, :, (end-n_test+1):(end-1)], 𝐱[:, :, (end-n_test+2):end]
4953

5054
return 𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test
5155
end
5256

53-
function get_dataloader(; i=0, Δt=2, ratio=0.9, batchsize=100)
54-
𝐱 = get_data(i=i) # size==(6, :)
55-
𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test = preprocess(𝐱, Δt=Δt, ratio=ratio) # size==(1, 2, 6, :)
57+
function get_dataloader(; n_file=10, Δt=2, ratio=0.9, batchsize=100)
58+
𝐱_train, 𝐲_train = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
59+
𝐱_test, 𝐲_test = Array{Float32}(undef, 2, 4, 0), Array{Float32}(undef, 2, 4, 0)
60+
for i in 0:(n_file-1)
61+
𝐱_train_i, 𝐲_train_i, 𝐱_test_i, 𝐲_test_i = preprocess(get_data(i=i), Δt=Δt, ratio=ratio)
62+
63+
𝐱_train, 𝐲_train = cat(𝐱_train, 𝐱_train_i, dims=3), cat(𝐲_train, 𝐲_train_i, dims=3)
64+
𝐱_test, 𝐲_test = cat(𝐱_test, 𝐱_test_i, dims=3), cat(𝐲_test, 𝐲_test_i, dims=3)
65+
end
5666

5767
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
5868
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)

0 commit comments

Comments
 (0)