
Commit b64ea84

refactor SuperResolution example

1 parent 61e944e

2 files changed (+66 −39 lines)


example/SuperResolution/src/SuperResolution.jl

Lines changed: 6 additions & 39 deletions
@@ -2,56 +2,20 @@ module SuperResolution
 
 using NeuralOperators
 using Flux
+using Flux.Losses: mse
+using Flux.Data: DataLoader
 using CUDA
 using JLD2
 
 include("data.jl")
+include("models.jl")
 
 function update_model!(model_file_path, model)
     model = cpu(model)
     jldsave(model_file_path; model)
     @warn "model updated!"
 end
 
-function train()
-    if has_cuda()
-        @info "CUDA is on"
-        device = gpu
-        CUDA.allowscalar(false)
-    else
-        device = cpu
-    end
-
-    m = Chain(
-        Dense(1, 64),
-        OperatorKernel(64=>64, (24, 24), gelu),
-        OperatorKernel(64=>64, (24, 24), gelu),
-        OperatorKernel(64=>64, (24, 24), gelu),
-        OperatorKernel(64=>64, (24, 24), gelu),
-        Dense(64, 1),
-    ) |> device
-
-    loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
-
-    opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
-
-    @info "gen data... "
-    @time loader_train, loader_test = get_dataloader()
-
-    losses = Float32[]
-    function validate()
-        validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
-        @info "loss: $validation_loss"
-
-        push!(losses, validation_loss)
-        (losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
-    end
-    call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
-
-    data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
-    Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
-end
-
 function get_model()
     f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
     model = f["model"]
@@ -60,4 +24,7 @@ function get_model()
     return model
 end
 
+loss(m, 𝐱, 𝐲) = mse(m(𝐱), 𝐲)
+loss(m, loader::DataLoader, device) = sum(loss(m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader)/length(loader)
+
 end
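
The refactor replaces the loss closure that lived inside train() with two mse-based loss methods selected by multiple dispatch: one for a single batch, one averaging over a whole DataLoader on a given device. A minimal sketch of how the two methods compose, using a toy Dense model and random data as stand-ins for the operator network and the example's dataloaders (the stand-ins are assumptions, not part of the commit):

using Flux
using Flux.Losses: mse
using Flux.Data: DataLoader

loss(m, 𝐱, 𝐲) = mse(m(𝐱), 𝐲)
loss(m, loader::DataLoader, device) = sum(loss(m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader) / length(loader)

m = Dense(1, 1)  # toy stand-in for the operator model
loader = DataLoader((rand(Float32, 1, 32), rand(Float32, 1, 32)), batchsize=8)

loss(m, first(loader)...)  # loss on one batch
loss(m, loader, cpu)       # mean loss over the whole loader

Passing the model as the first argument, instead of closing over it, is what lets the same loss name serve both per-batch training and full-loader evaluation.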
example/SuperResolution/src/models.jl (new file)

Lines changed: 60 additions & 0 deletions
using ProgressMeter: Progress, next!  # assumed import; Progress/next! are used below

function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
    # GPU config
    if cuda && CUDA.has_cuda()
        device = gpu
        CUDA.allowscalar(false)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    @info "gen data... "
    @time loader_train, loader_test = get_dataloader()

    # build model
    m = Chain(
        Dense(1, 64),
        OperatorKernel(64=>64, (24, 24), gelu),
        OperatorKernel(64=>64, (24, 24), gelu),
        OperatorKernel(64=>64, (24, 24), gelu),
        OperatorKernel(64=>64, (24, 24), gelu),
        Dense(64, 1),
    ) |> device

    # optimizer
    opt = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η))

    # parameters
    ps = Flux.params(m)

    # training
    min_loss = Inf32
    train_steps = 0
    @info "Start Training, total $(epochs) epochs"
    for epoch = 1:epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(loader_train))

        for (𝐱, 𝐲) in loader_train
            grad = gradient(() -> loss(m, 𝐱 |> device, 𝐲 |> device), ps)
            Flux.Optimise.update!(opt, ps, grad)
            train_loss = loss(m, loader_train, device)
            test_loss = loss(m, loader_test, device)

            # progress meter
            next!(progress; showvalues=[
                (:train_loss, train_loss),
                (:test_loss, test_loss)
            ])

            # checkpoint whenever the test loss improves
            if test_loss ≤ min_loss
                update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
                min_loss = test_loss  # track the best loss seen so far
            end

            train_steps += 1
        end
    end

    return m
end
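
A rough usage sketch for the new entry point (a hypothetical call site; it assumes the example's project environment, with NeuralOperators, Flux, CUDA, JLD2, and ProgressMeter installed):

using SuperResolution

m = SuperResolution.train_mno(cuda=true, η=1f-3, λ=1f-4, epochs=50)
m = SuperResolution.get_model()  # reload the checkpoint written by update_model!

Because train_mno checkpoints to model/model.jld2 whenever the test loss improves, get_model returns the best model observed during training rather than the final iterate.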
