@@ -26,7 +26,7 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
     opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η))

     # parameters
-    ps = Flux.params(model)
+    ps = Flux.params(m)

     # training
     min_loss = Inf32
@@ -37,10 +37,10 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
         progress = Progress(length(loader_train))

         for (𝐱, 𝐲) in loader_train
-            grad = gradient(() -> loss(model, 𝐱 |> device, 𝐲 |> device), ps)
+            grad = gradient(() -> loss(m, 𝐱 |> device, 𝐲 |> device), ps)
             Flux.Optimise.update!(opt, ps, grad)
-            train_loss = loss(model, loader_train, device)
-            test_loss = loss(model, loader_test, device)
+            train_loss = loss(m, loader_train, device)
+            test_loss = loss(m, loader_test, device)

             # progress meter
             next!(progress; showvalues=[
@@ -49,12 +49,78 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
             ])

             if test_loss ≤ min_loss
-                update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
+                update_model!(joinpath(@__DIR__, "../model/mno.jld2"), m)
+                min_loss = test_loss
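+                # track the best test loss so the checkpoint above fires only on improvement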
             end

             train_steps += 1
         end
     end

     return m
-end
+end
+
+function train_gno(; channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50)
+    # GPU config
+    if cuda && CUDA.has_cuda()
+        device = gpu
+        CUDA.allowscalar(false)
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    @info "gen data... "
+    @time loader_train, loader_test = get_dataloader()
+
+    # build model
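+    # assumed: grid([12, 8]) yields a 12 × 8 grid graph (96 nodes); FeaturedGraph wraps it so every kernel layer shares the same topology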
+    g = grid([12, 8])
+    fg = FeaturedGraph(g)
+
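+    # Dense(1, 64) lifts the scalar input to node features, four graph kernel layers transform them, and Dense(64, 1) projects back (matches the default channel = 64)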
+    m = Chain(
+        Dense(1, 64),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        WithGraph(fg, GraphKernel(Dense(2channel, channel, gelu), channel)),
+        Dense(64, 1),
+    ) |> device
+
+    # optimizer
+    opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η))
+
+    # parameters
+    ps = Flux.params(m)
+
+    # training
+    min_loss = Inf32
+    train_steps = 0
+    @info "Start Training, total $(epochs) epochs"
+    for epoch = 1:epochs
+        @info "Epoch $(epoch)"
+        progress = Progress(length(loader_train))
+
+        for (𝐱, 𝐲) in loader_train
+            grad = gradient(() -> loss(m, 𝐱 |> device, 𝐲 |> device), ps)
+            Flux.Optimise.update!(opt, ps, grad)
+            train_loss = loss(m, loader_train, device)
+            test_loss = loss(m, loader_test, device)
+
+            # progress meter
+            next!(progress; showvalues=[
+                (:train_loss, train_loss),
+                (:test_loss, test_loss)
+            ])
+
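+            # checkpoint only when the test loss improves on the best seen so far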
+            if test_loss ≤ min_loss
+                update_model!(joinpath(@__DIR__, "../model/gno.jld2"), m)
+                min_loss = test_loss
+            end
+
+            train_steps += 1
+        end
+    end
+
+    return m
+end