@@ -18,25 +18,36 @@ function train()
18
18
end
19
19
20
20
m = Chain (
21
- FourierOperator (6 => 6 , (16 , ), gelu),
22
- FourierOperator (6 => 6 , (16 , ), gelu),
23
- FourierOperator (6 => 6 , (16 , ), gelu),
21
+ FourierOperator (6 => 6 , (16 , ), relu),
22
+ FourierOperator (6 => 6 , (16 , ), relu),
23
+ FourierOperator (6 => 6 , (16 , ), relu),
24
+ FourierOperator (6 => 6 , (16 , ), relu),
24
25
FourierOperator (6 => 6 , (16 , )),
25
26
) |> device
26
27
27
28
loss (𝐱, 𝐲) = sum (abs2, 𝐲 .- m (𝐱)) / size (𝐱)[end ]
28
29
30
+ opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-3 ))
31
+
29
32
loader_train, loader_test = get_dataloader ()
30
33
34
+ data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
35
+
36
+ loss_bounds = [0.3 , 0.05 , 0.01 ]
31
37
function validate ()
32
- validation_losses = [loss (device (𝐱), device (𝐲)) for (𝐱, 𝐲) in loader_test]
33
- @info " loss: $(sum (validation_losses)/ length (loader_test)) "
38
+ validation_loss = sum (loss (device (𝐱), device (𝐲)) for (𝐱, 𝐲) in loader_test)/ length (loader_test)
39
+ @info " loss: $validation_loss "
40
+
41
+ isempty (loss_bounds) && return
42
+ if validation_loss < loss_bounds[1 ]
43
+ @warn " change η"
44
+ opt. os[2 ]. eta /= 2
45
+ popfirst! (loss_bounds)
46
+ end
34
47
end
35
48
36
- data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
37
- opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-4 ))
38
- call_back = Flux. throttle (validate, 0.5 , leading= false , trailing= true )
39
- Flux. @epochs 500 @time (Flux. train! (loss, params (m), data, opt, cb= call_back))
49
+ call_back = Flux. throttle (validate, 1 , leading= false , trailing= true )
50
+ Flux. @epochs 300 @time (Flux. train! (loss, params (m), data, opt, cb= call_back))
40
51
end
41
52
42
53
end
0 commit comments