This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f068908
Revise training process for SuperRes
1 parent: 7daf01f

2 files changed: +126 -22 lines

example/SuperResolution/Project.toml — 3 additions & 0 deletions
```diff
@@ -2,12 +2,15 @@ name = "SuperResolution"
 uuid = "a8258e1f-331c-4af2-83e9-878628278453"
 
 [deps]
+BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
 GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
```
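
The three added dependencies (BSON, FluxTraining, MLUtils) back the new checkpointing, training-loop, and data-splitting code in the module below. A minimal setup sketch, assuming Julia is started from the repository root:

```julia
# Minimal environment setup sketch (assumed paths): activate the example's
# project and install the dependencies declared in its Project.toml.
using Pkg

Pkg.activate("example/SuperResolution")
Pkg.instantiate()
```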
SuperResolution module source (second changed file) — 123 additions & 22 deletions
```diff
@@ -1,33 +1,134 @@
 module SuperResolution
 
-using NeuralOperators
-using Flux
-using Flux.Losses: mse
-using Flux.Data: DataLoader
-using GeometricFlux
-using Graphs
-using CUDA
-using JLD2
-using ProgressMeter: Progress, next!
+using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
+using NeuralOperators, Flux
+using CUDA, FluxTraining, BSON
 
-include("data.jl")
-include("models.jl")
+function circle(n, m; Re=250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
+    # Set physical parameters
+    U, R, center = 1., m/8., [m/2, m/2]
+    ν = U * R / Re
 
-function update_model!(model_file_path, model)
-    model = cpu(model)
-    jldsave(model_file_path; model)
-    @info "model updated!"
+    body = AutoBody((x,t) -> LinearAlgebra.norm2(x .- center) - R)
+    Simulation((n+2, m+2), [U, 0.], R; ν, body)
 end
 
-function get_model()
-    f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
-    model = f["model"]
-    close(f)
+function gen_data(ts::AbstractRange; resolution=2)
+    @info "gen data with $(resolution)x resolution... "
+    p = Progress(length(ts))
+
+    n, m = resolution * 3(2^5), resolution * 2^6
+    circ = circle(n, m)
+
+    𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
+    for (i, t) in enumerate(ts)
+        sim_step!(circ, t)
+        𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
+
+        next!(p)
+    end
+
+    return 𝐩s
+end
```
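
The new data pipeline replaces the old `data.jl`/`models.jl` includes: `circle` builds a WaterLily simulation of flow past a circle, and `gen_data` steps it through `ts`, recording the interior pressure field at each time. A hypothetical smoke test (the short time range and the plot orientation are assumptions, not part of the commit):

```julia
# Hypothetical sanity check for the data generator above: simulate a short
# time range at 1x resolution (arrays of shape (1, 96, 64, T)) and render
# one pressure snapshot with Plots, which is already in the example's deps.
using Plots
using SuperResolution

𝐩s = SuperResolution.gen_data(LinRange(100, 200, 50); resolution=1)
heatmap(𝐩s[1, :, :, end]'; title="pressure field at final step")  # transpose so x runs horizontally
```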
```diff
+
+function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
+    data = gen_data(ts, resolution=1)
+    data_train, data_validate = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
+
+    data = gen_data(ts, resolution=2)
+    data_test = (𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])
 
-    return model
+    loader_train = DataLoader(data_train, batchsize=batchsize, shuffle=true)
+    loader_validate = DataLoader(data_validate, batchsize=batchsize, shuffle=false)
+    loader_test = DataLoader(data_test, batchsize=batchsize, shuffle=false)
+
+    return (training=loader_train, validation=loader_validate, testing=loader_test)
+end
+
```
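`get_dataloader` pairs each pressure snapshot with the one that follows it, so the model learns a one-step Markov map; training and validation data are generated at 1x resolution while the test set uses 2x, which is the super-resolution evaluation. A minimal sketch of the pairing and split with a stand-in array instead of simulation output:

```julia
# Minimal sketch of the Markov pairing used in `get_dataloader`. Shapes
# follow the diff: (channel, x, y, time); 𝐱 holds frames 1:T-1 and 𝐲 the
# frames one step later, so observations live on the last axis.
using MLUtils

data = rand(Float32, 1, 96, 64, 10)          # stand-in for gen_data(ts, resolution=1)
pairs = (𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])
train, valid = splitobs(shuffleobs(pairs), at=0.95)
@assert numobs(train) + numobs(valid) == 9   # 9 input/target pairs from 10 frames
```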
```diff
+struct TestPhase<:FluxTraining.AbstractValidationPhase end
+
+FluxTraining.phasedataiter(::TestPhase) = :testing
+
+function FluxTraining.step!(learner, phase::TestPhase, batch)
+    xs, ys = batch
+    FluxTraining.runstep(learner, phase, (xs=xs, ys=ys)) do _, state
+        state.ŷs = learner.model(state.xs)
+        state.loss = learner.lossfn(state.ŷs, state.ys)
+    end
 end
 
-loss(m, 𝐱, 𝐲) = mse(m(𝐱), 𝐲)
-loss(m, loader::DataLoader, device) = sum(loss(m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader)/length(loader)
+function fit!(learner, nepochs::Int, (trainiter, validiter, testiter))
+    for i in 1:nepochs
+        epoch!(learner, TrainingPhase(), trainiter)
+        epoch!(learner, ValidationPhase(), validiter)
+        epoch!(learner, TestPhase(), testiter)
+    end
+end
 
+function fit!(learner, nepochs::Int)
+    fit!(learner, nepochs, (learner.data.training, learner.data.validation, learner.data.testing))
 end
+
+function train(; epochs=50)
+    if has_cuda()
+        @info "CUDA is on"
+        device = gpu
+        CUDA.allowscalar(false)
+    else
+        device = cpu
+    end
+
+    model = MarkovNeuralOperator(ch=(1, 64, 64, 64, 64, 64, 1), modes=(24, 24), σ=gelu)
+    data = get_dataloader()
+    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    loss_func = l₂loss
+
+    learner = Learner(
+        model, data, optimiser, loss_func,
+        ToDevice(device, device),
+        # Checkpointer(joinpath(@__DIR__, "../model/"))
+    )
+
+    fit!(learner, epochs)
+
+    return learner
+end
+
+function get_model()
+    model_path = joinpath(@__DIR__, "../model/")
+    model_file = readdir(model_path)[end]
+
+    return BSON.load(joinpath(model_path, model_file), @__MODULE__)[:model]
+end
+
+# using NeuralOperators
+# using Flux
+# using Flux.Losses: mse
+# using Flux.Data: DataLoader
+# using GeometricFlux
+# using Graphs
+# using CUDA
+# using JLD2
+# using ProgressMeter: Progress, next!
+
+# include("data.jl")
+# include("models.jl")
+
+# function update_model!(model_file_path, model)
+#     model = cpu(model)
+#     jldsave(model_file_path; model)
+#     @info "model updated!"
+# end
+
+# function get_model()
+#     f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
+#     model = f["model"]
+#     close(f)
+
+#     return model
+# end
+
+# loss(m, 𝐱, 𝐲) = mse(m(𝐱), 𝐲)
+# loss(m, loader::DataLoader, device) = sum(loss(m, 𝐱 |> device, 𝐲 |> device) for (𝐱, 𝐲) in loader)/length(loader)
+
+end # module
```
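
Putting it together, a hypothetical end-to-end run of the revised example. Note that `get_model` reads the newest file in `../model/`, so it assumes the commented-out `Checkpointer` callback in `train` has been re-enabled and has written at least one BSON checkpoint:

```julia
# Hypothetical usage of the revised module: `train` generates the flow data,
# fits the Markov neural operator through the three FluxTraining phases, and
# returns the learner; `get_model` reloads the latest saved checkpoint.
using SuperResolution

learner = SuperResolution.train(epochs=50)
model = SuperResolution.get_model()
```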
