Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 29d55da

Browse files
committed
implement get_dataloader
1 parent 1bbeb43 commit 29d55da

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

example/Burgers/src/Burgers.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,10 @@ function train()
3030
Dense(128, 1),
3131
flatten
3232
) |> device
33+
3334
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
3435

35-
n_train = 1800
36-
n_test = 200
37-
batchsize = 100
38-
𝐱, 𝐲 = get_burgers_data(n=2048)
39-
40-
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
41-
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
42-
43-
𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
44-
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
36+
loader_train, loader_test = get_dataloader()
4537

4638
function validate()
4739
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]

example/Burgers/src/data.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,15 @@ function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples
3030

3131
return x_loc_data, y_data
3232
end
33+
34+
function get_dataloader(; n_train=1800, n_test=200, batchsize=100)
35+
𝐱, 𝐲 = get_burgers_data(n=2048)
36+
37+
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
38+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
39+
40+
𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
41+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
42+
43+
return loader_train, loader_test
44+
end

0 commit comments

Comments
 (0)