Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 21e04b9

Browse files
committed
revise data
1 parent 6b3bb4b commit 21e04b9

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

example/DoublePendulum/src/data.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ function get_data(; i=0, n=-1)
3838
return Float32.(data)
3939
end
4040

41-
function get_dataloader(; i=0, n_train=15001, n_test=1501, Δn=1024, batchsize=100)
42-
x = reshape(get_data(; i=i, n=-1), :)
41+
function get_dataloader(; i=0, n_train=15734, n_test=2048, Δn=1024, batchsize=100)
42+
x = reshape(get_data(; i=i, n=-1), :) # size==(6, 17782)
4343
𝐱 = reshape(vcat([x[i:(i+6Δn-1)] for i in 1:6:(length(x)-6(Δn-1))]...), 1, 6, 1024, :)
4444

45-
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-1)], 𝐱[:, :, :, 2:n_train]
45+
𝐱_train, 𝐲_train = 𝐱[:, :, :, 1:(n_train-Δn)], 𝐱[:, :, :, 1+Δn:n_train]
4646
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
4747

48-
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-1)], 𝐱[:, :, :, (end-n_test+2):end]
48+
𝐱_test, 𝐲_test = 𝐱[:, :, :, (end-n_test+1):(end-Δn)], 𝐱[:, :, :, (end-n_test+1+Δn):end]
4949
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
5050

5151
return loader_train, loader_test

0 commit comments

Comments
 (0)