@@ -3,12 +3,19 @@ module DoublePendulum
 using NeuralOperators
 using Flux
 using CUDA
+using JLD2
 
 include("data.jl")
 
 __init__() = register_double_pendulum_chaotic()
 
-function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
+function update_model!(model_file_path, model)
+    model = cpu(model)
+    jldsave(model_file_path; model)
+    @warn "model updated!"
+end
+
+function train(; loss_bounds=[1, 0.3, 0.1, 0.05])
     if has_cuda()
         @info "CUDA is on"
         device = gpu
@@ -21,6 +28,7 @@ function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
         FourierOperator(6 => 64, (16, ), relu),
         FourierOperator(64 => 64, (16, ), relu),
         FourierOperator(64 => 64, (16, ), relu),
+        FourierOperator(64 => 64, (16, ), relu),
         FourierOperator(64 => 6, (16, )),
     ) |> device
 
@@ -32,20 +40,24 @@ function train(; loss_bounds=[1, 0.2, 0.1, 0.05, 0.02])
 
     data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
 
+    losses = Float32[]
     function validate()
         validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test) / length(loader_test)
         @info "loss: $validation_loss"
 
+        push!(losses, validation_loss)
+        (losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
+
         isempty(loss_bounds) && return
         if validation_loss < loss_bounds[1]
             @warn "change η"
             opt.os[2].eta /= 2
             popfirst!(loss_bounds)
         end
     end
-
-    call_back = Flux.throttle(validate, 1, leading=false, trailing=true)
-    Flux.@epochs 300 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
+    call_back = Flux.throttle(validate, 10, leading=false, trailing=true)
+
+    Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
 end
 
 end
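
A minimal usage sketch (not part of the commit) of how the checkpoint written by `update_model!` could be read back for inference. `jldsave(model_file_path; model)` stores the model under the key `"model"`, so JLD2's `load` retrieves it by that name; the path below simply mirrors the one used in `validate`.

```julia
using JLD2
using Flux

# Load the best model saved during training (the key "model" comes from the
# keyword name passed to jldsave).
model = load(joinpath(@__DIR__, "../model/model.jld2"), "model")

# The checkpoint was moved to the CPU before saving; `gpu` moves it back
# when CUDA is functional and is a no-op otherwise.
model = gpu(model)
```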