@@ -36,37 +36,37 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
36
36
data_train, data_validate = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
37
37
38
38
data = gen_data (ts, resolution= 2 )
39
- _, data_test = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
39
+ _, data_super_res = splitobs (shuffleobs ((𝐱= data[:, :, :, 1 : end - 1 ], 𝐲= data[:, :, :, 2 : end ])), at= ratio)
40
40
41
41
loader_train = DataLoader (data_train, batchsize= batchsize, shuffle= true )
42
42
loader_validate = DataLoader (data_validate, batchsize= batchsize, shuffle= false )
43
- loader_test = DataLoader (data_test , batchsize= batchsize, shuffle= false )
43
+ loader_super_res = DataLoader (data_super_res , batchsize= batchsize, shuffle= false )
44
44
45
- return (training= loader_train, validation= loader_validate, testing = loader_test )
45
+ return (training= loader_train, validation= loader_validate, super_res = loader_super_res )
46
46
end
47
47
48
- struct TestPhase <: FluxTraining.AbstractValidationPhase end
48
+ struct SuperResPhase <: FluxTraining.AbstractValidationPhase end
49
49
50
- FluxTraining. phasedataiter (:: TestPhase ) = :testing
50
+ FluxTraining. phasedataiter (:: SuperResPhase ) = :super_res
51
51
52
- function FluxTraining. step! (learner, phase:: TestPhase , batch)
52
+ function FluxTraining. step! (learner, phase:: SuperResPhase , batch)
53
53
xs, ys = batch
54
54
FluxTraining. runstep (learner, phase, (xs= xs, ys= ys)) do _, state
55
55
state. ŷs = learner. model (state. xs)
56
56
state. loss = learner. lossfn (state. ŷs, state. ys)
57
57
end
58
58
end
59
59
60
- function fit! (learner, nepochs:: Int , (trainiter, validiter, testiter ))
60
+ function fit! (learner, nepochs:: Int , (loader_train, loader_validate, loader_super_res ))
61
61
for i in 1 : nepochs
62
- epoch! (learner, TrainingPhase (), trainiter )
63
- epoch! (learner, ValidationPhase (), validiter )
64
- epoch! (learner, TestPhase (), testiter )
62
+ epoch! (learner, TrainingPhase (), loader_train )
63
+ epoch! (learner, ValidationPhase (), loader_validate )
64
+ epoch! (learner, SuperResPhase (), loader_super_res )
65
65
end
66
66
end
67
67
68
68
function fit! (learner, nepochs:: Int )
69
- fit! (learner, nepochs, (learner. data. training, learner. data. validation, learner. data. testing ))
69
+ fit! (learner, nepochs, (learner. data. training, learner. data. validation, learner. data. super_res ))
70
70
end
71
71
72
72
function train (; cuda= true , η₀= 1f-3 , λ= 1f-4 , epochs= 50 )
0 commit comments