Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 14ee136

Browse files
committed
implement get_model and refactor
1 parent a2e1a37 commit 14ee136

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

example/DoublePendulum/notebook/data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ begin
1414
end
1515

1616
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
17-
data = get_double_pendulum_chaotic_data(i=0, n=-1)
17+
data = get_data(i=0, n=-1)
1818

1919
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94
2020
begin

example/DoublePendulum/src/DoublePendulum.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,12 @@ function train(; loss_bounds=[0.05])
5959
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
6060
end
6161

62+
function get_model()
63+
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
64+
model = f["model"]
65+
close(f)
66+
67+
return model
68+
end
69+
6270
end

example/DoublePendulum/src/data.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ using DataDeps
22
using CSV
33
using DataFrames
44

5-
export get_double_pendulum_chaotic_data
6-
75
function register_double_pendulum_chaotic()
86
register(DataDep(
97
"DoublePendulumChaotic",
@@ -26,7 +24,7 @@ function register_double_pendulum_chaotic()
2624
))
2725
end
2826

29-
function get_double_pendulum_chaotic_data(; i=0, n=-1)
27+
function get_data(; i=0, n=-1)
3028
data_path = joinpath(datadep"DoublePendulumChaotic", "original", "dpc_dataset_csv")
3129
df = CSV.read(
3230
joinpath(data_path, "$i.csv"),
@@ -41,7 +39,7 @@ function get_double_pendulum_chaotic_data(; i=0, n=-1)
4139
end
4240

4341
function get_dataloader(; i=0, n_train=15001, n_test=1501, Δn=1024, batchsize=100)
44-
x = reshape(get_double_pendulum_chaotic_data(; i=i, n=-1), :)
42+
x = reshape(get_data(; i=i, n=-1), :)
4543
𝐱 = reshape(vcat([x[i:(i+6Δn-1)] for i in 1:6:(length(x)-6(Δn-1))]...), 1, 6, 1024, :)
4644

4745
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-1)], 𝐱[:, :, :, 2:n_train]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "double pendulum" begin
2-
xs = get_double_pendulum_chaotic_data(i=0, n=100)
2+
xs = DoublePendulum.get_data(i=0, n=100)
33

44
@test size(xs) == (6, 100)
55
end

0 commit comments

Comments
 (0)