@@ -2,56 +2,20 @@ module SuperResolution
2
2
3
3
using NeuralOperators
4
4
using Flux
5
+ using Flux. Losses: mse
6
+ using Flux. Data: DataLoader
5
7
using CUDA
6
8
using JLD2
7
9
8
10
include (" data.jl" )
11
+ include (" models.jl" )
9
12
10
13
function update_model! (model_file_path, model)
11
14
model = cpu (model)
12
15
jldsave (model_file_path; model)
13
16
@warn " model updated!"
14
17
end
15
18
16
- function train ()
17
- if has_cuda ()
18
- @info " CUDA is on"
19
- device = gpu
20
- CUDA. allowscalar (false )
21
- else
22
- device = cpu
23
- end
24
-
25
- m = Chain (
26
- Dense (1 , 64 ),
27
- OperatorKernel (64 => 64 , (24 , 24 ), gelu),
28
- OperatorKernel (64 => 64 , (24 , 24 ), gelu),
29
- OperatorKernel (64 => 64 , (24 , 24 ), gelu),
30
- OperatorKernel (64 => 64 , (24 , 24 ), gelu),
31
- Dense (64 , 1 ),
32
- ) |> device
33
-
34
- loss (𝐱, 𝐲) = sum (abs2, 𝐲 .- m (𝐱)) / size (𝐱)[end ]
35
-
36
- opt = Flux. Optimiser (WeightDecay (1f-4 ), Flux. ADAM (1f-3 ))
37
-
38
- @info " gen data... "
39
- @time loader_train, loader_test = get_dataloader ()
40
-
41
- losses = Float32[]
42
- function validate ()
43
- validation_loss = sum (loss (device (𝐱), device (𝐲)) for (𝐱, 𝐲) in loader_test)/ length (loader_test)
44
- @info " loss: $validation_loss "
45
-
46
- push! (losses, validation_loss)
47
- (losses[end ] == minimum (losses)) && update_model! (joinpath (@__DIR__ , " ../model/model.jld2" ), m)
48
- end
49
- call_back = Flux. throttle (validate, 5 , leading= false , trailing= true )
50
-
51
- data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52
- Flux. @epochs 50 @time (Flux. train! (loss, params (m), data, opt, cb= call_back))
53
- end
54
-
55
19
function get_model ()
56
20
f = jldopen (joinpath (@__DIR__ , " ../model/model.jld2" ))
57
21
model = f[" model" ]
@@ -60,4 +24,7 @@ function get_model()
60
24
return model
61
25
end
62
26
27
+ loss (m, 𝐱, 𝐲) = mse (m (𝐱), 𝐲)
28
+ loss (m, loader:: DataLoader , device) = sum (loss (m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader)/ length (loader)
29
+
63
30
end
0 commit comments