@@ -15,7 +15,7 @@ function update_model!(model_file_path, model)
15
15
@warn " model updated!"
16
16
end
17
17
18
- function train (; loss_bounds= [0.05 ])
18
+ function train (; loss_bounds= [])
19
19
if has_cuda ()
20
20
@info " CUDA is on"
21
21
device = gpu
@@ -25,15 +25,29 @@ function train(; loss_bounds=[0.05])
25
25
end
26
26
27
27
m = Chain (
28
- FourierOperator (1 => 64 , (6 , 64 , ), relu),
29
- FourierOperator (64 => 64 , (6 , 64 , ), relu),
30
- FourierOperator (64 => 64 , (6 , 64 , ), relu),
31
- FourierOperator (64 => 1 , (6 , 64 , )),
28
+ Dense (1 , 64 , gelu),
29
+ FourierOperator (64 => 64 , (12 , ), gelu),
30
+ FourierOperator (64 => 64 , (12 , ), gelu),
31
+ FourierOperator (64 => 64 , (12 , ), gelu),
32
+ FourierOperator (64 => 64 , (12 , ), gelu),
33
+ FourierOperator (64 => 64 , (12 , ), gelu),
34
+ FourierOperator (64 => 64 , (12 , ), gelu),
35
+ FourierOperator (64 => 64 , (12 , ), gelu),
36
+ FourierOperator (64 => 64 , (12 , ), gelu),
37
+ FourierOperator (64 => 64 , (12 , ), gelu),
38
+ FourierOperator (64 => 64 , (12 , ), gelu),
39
+ FourierOperator (64 => 64 , (12 , ), gelu),
40
+ FourierOperator (64 => 64 , (12 , ), gelu),
41
+ FourierOperator (64 => 64 , (12 , ), gelu),
42
+ FourierOperator (64 => 64 , (12 , ), gelu),
43
+ FourierOperator (64 => 64 , (12 , ), gelu),
44
+ FourierOperator (64 => 64 , (12 , )),
45
+ Dense (64 , 1 )
32
46
) |> device
33
47
34
48
loss (𝐱, 𝐲) = sum (abs2, 𝐲 .- m (𝐱)) / size (𝐱)[end ]
35
49
36
- opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-2 ))
50
+ opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-3 ))
37
51
38
52
loader_train, loader_test = get_dataloader ()
39
53
0 commit comments