Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 54e2dec

Browse files
committed
Reduce testing data and enable checkpointer
1 parent b0c750c commit 54e2dec

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

example/SuperResolution/src/SuperResolution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
3636
data_train, data_validate = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
3737

3838
data = gen_data(ts, resolution=2)
39-
data_test = (𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])
39+
_, data_test = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
4040

4141
loader_train = DataLoader(data_train, batchsize=batchsize, shuffle=true)
4242
loader_validate = DataLoader(data_validate, batchsize=batchsize, shuffle=false)
@@ -86,7 +86,7 @@ function train(; epochs=50)
8686
learner = Learner(
8787
model, data, optimiser, loss_func,
8888
ToDevice(device, device),
89-
# Checkpointer(joinpath(@__DIR__, "../model/"))
89+
Checkpointer(joinpath(@__DIR__, "../model/"))
9090
)
9191

9292
fit!(learner, epochs)

0 commit comments

Comments
 (0)