Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit dae78a1

Browse files
committed
Refactor
1 parent 6c3e869 commit dae78a1

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

example/SuperResolution/src/SuperResolution.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,37 +36,37 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
3636
data_train, data_validate = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
3737

3838
data = gen_data(ts, resolution=2)
39-
_, data_test = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
39+
_, data_super_res = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
4040

4141
loader_train = DataLoader(data_train, batchsize=batchsize, shuffle=true)
4242
loader_validate = DataLoader(data_validate, batchsize=batchsize, shuffle=false)
43-
loader_test = DataLoader(data_test, batchsize=batchsize, shuffle=false)
43+
loader_super_res = DataLoader(data_super_res, batchsize=batchsize, shuffle=false)
4444

45-
return (training=loader_train, validation=loader_validate, testing=loader_test)
45+
return (training=loader_train, validation=loader_validate, super_res=loader_super_res)
4646
end
4747

48-
struct TestPhase<:FluxTraining.AbstractValidationPhase end
48+
struct SuperResPhase<:FluxTraining.AbstractValidationPhase end
4949

50-
FluxTraining.phasedataiter(::TestPhase) = :testing
50+
FluxTraining.phasedataiter(::SuperResPhase) = :super_res
5151

52-
function FluxTraining.step!(learner, phase::TestPhase, batch)
52+
function FluxTraining.step!(learner, phase::SuperResPhase, batch)
5353
xs, ys = batch
5454
FluxTraining.runstep(learner, phase, (xs=xs, ys=ys)) do _, state
5555
state.ŷs = learner.model(state.xs)
5656
state.loss = learner.lossfn(state.ŷs, state.ys)
5757
end
5858
end
5959

60-
function fit!(learner, nepochs::Int, (trainiter, validiter, testiter))
60+
function fit!(learner, nepochs::Int, (loader_train, loader_validate, loader_super_res))
6161
for i in 1:nepochs
62-
epoch!(learner, TrainingPhase(), trainiter)
63-
epoch!(learner, ValidationPhase(), validiter)
64-
epoch!(learner, TestPhase(), testiter)
62+
epoch!(learner, TrainingPhase(), loader_train)
63+
epoch!(learner, ValidationPhase(), loader_validate)
64+
epoch!(learner, SuperResPhase(), loader_super_res)
6565
end
6666
end
6767

6868
function fit!(learner, nepochs::Int)
69-
fit!(learner, nepochs, (learner.data.training, learner.data.validation, learner.data.testing))
69+
fit!(learner, nepochs, (learner.data.training, learner.data.validation, learner.data.super_res))
7070
end
7171

7272
function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=50)

0 commit comments

Comments
 (0)