Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 79009ca

Browse files
authored
Merge pull request #24 from foldfelis/sr
implement example for super resolution with MNO
2 parents 05873a1 + 1286a8a commit 79009ca

File tree

13 files changed

+333
-1
lines changed

13 files changed

+333
-1
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ PDE training examples are provided in `example` folder.
9191

9292
[Time dependent Navier-Stokes equation](example/FlowOverCircle)
9393

94+
### Super Resolution with MNO
95+
96+
[Super resolution on time dependent Navier-Stokes equation](example/SuperResolution)
97+
9498
## Roadmap
9599

96100
- [x] `FourierOperator` layer

docs/src/assets/notebook/super_resolution_mno.jl.html

Lines changed: 75 additions & 0 deletions
Large diffs are not rendered by default.

example/FlowOverCircle/src/data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function gen_data(ts::AbstractRange)
2222
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
2323
for (i, t) in enumerate(ts)
2424
sim_step!(circ, t)
25-
𝐩s[:, :, :, i] = Float32.(circ.flow.p)[2:end-1, 2:end-1]
25+
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
2626
end
2727

2828
return 𝐩s
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "SuperResolution"
2+
uuid = "a8258e1f-331c-4af2-83e9-878628278453"
3+
4+
[deps]
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
10+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
11+
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"
14+
15+
[extras]
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["Test"]

example/SuperResolution/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Super Resolution
2+
3+
The time-dependent Navier-Stokes equation is learned by the `MarkovNeuralOperator` using only single-time-step information.
4+
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/super_resolution_mno.jl.html).
5+
6+
Beyond simply training an MNO, here we train the model at a lower resolution (96x64) and infer results at a higher resolution (192x128).
7+
8+
| **Ground Truth** | **Inferred** |
9+
|:----------------:|:--------------:|
10+
| ![](gallery/ans.gif) | ![](gallery/inferenced.gif) |
11+
12+
Change directory to `example/SuperResolution` and use the following command to train the model:
13+
14+
```julia
15+
$ julia --proj
16+
17+
julia> using SuperResolution; SuperResolution.train()
18+
```
408 KB
Loading
514 KB
Loading

example/SuperResolution/model/.gitkeep

Whitespace-only changes.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
### A Pluto.jl notebook ###
2+
# v0.16.1
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024
8+
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")
9+
10+
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
11+
using SuperResolution, Plots
12+
13+
# ╔═╡ 50ce80a3-a1e8-4ba9-a032-dad315bcb432
14+
md"
15+
# Super Resolution with MNO
16+
17+
JingYu Ning
18+
"
19+
20+
# ╔═╡ 59769504-ebd5-4c6f-981f-d03826d8e34a
21+
md"
22+
This demo trains a Markov neural operator (MNO), introduced by [Zongyi Li *et al.*](https://arxiv.org/abs/2106.06898), using only single-time-step information. The operator is then composed into a Markov chain to infer the Navier-Stokes equations."
23+
24+
# ╔═╡ 823b3547-6723-43cf-85e6-cc6eb44efea1
25+
md"
26+
## Generate data
27+
"
28+
29+
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
30+
# Generate n consecutive simulation snapshots, starting at t = 100.
begin
31+
n = 10
32+
# LinRange(100, 100+n-1, n) yields integer-spaced time steps 100, 101, …, 109.
data = SuperResolution.gen_data(LinRange(100, 100+n-1, n))
33+
end;
34+
35+
# ╔═╡ 5531bba6-94bd-4c99-be8c-43fe19ad8a60
36+
md"
37+
## Training
38+
"
39+
40+
# ╔═╡ 74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
41+
md"
42+
Beyond simply training an MNO, here we train the model at a lower resolution (96x64) and infer results at a higher resolution (192x128).
43+
"
44+
45+
# ╔═╡ f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
46+
# Animate the ground-truth snapshots, downsampled 2x in both spatial
# dimensions (i.e. at the lower training resolution).
begin
47+
anim = @animate for i in 1:size(data)[end]
48+
# Drop the singleton channel dim and transpose so axes match the flow field.
heatmap(data[1, 1:2:end, 1:2:end, i]', color=:coolwarm, clim=(-1.5, 1.5))
49+
# Overlay a filled black disc at the obstacle position.
# NOTE(review): both coordinates are derived from size(data, 3) — presumably
# the cylinder center; confirm against the WaterLily solver setup.
scatter!(
50+
[size(data, 3)÷4-1], [size(data, 3)÷4-1],
51+
markersize=45, color=:black, legend=false, ticks=false
52+
)
53+
# Label each frame with its index.
annotate!(5, 5, text("i=$i", :left))
54+
end
55+
gif(anim, fps=2)
56+
end
57+
58+
# ╔═╡ 55058635-c7e9-4ee3-81c2-0153e84f4c8e
59+
md"
60+
## Inference
61+
62+
Use the first data generated above as the initial state, and apply the operator recurrently.
63+
"
64+
65+
# ╔═╡ fbc287b8-f232-4350-9948-2091908e5a30
66+
# Roll out the learned one-step operator as a Markov chain: seed with the
# first ground-truth snapshot, then feed each prediction back into the model.
begin
67+
m = SuperResolution.get_model()
68+
69+
states = Array{Float32}(undef, size(data))
70+
states[:, :, :, 1] .= view(data, :, :, :, 1)
71+
# The i:i range keeps the singleton time/batch dimension so the model
# receives a 4-D array, as it does during training.
for i in 2:size(data)[end]
72+
states[:, :, :, i:i] .= m(view(states, :, :, :, i-1:i-1))
73+
end
74+
end
75+
76+
# ╔═╡ a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
77+
# Animate the model rollout at full resolution (no 2x downsampling here,
# unlike the ground-truth animation above — this is the super-resolution step).
begin
78+
anim_model = @animate for i in 1:size(states)[end]
79+
# Drop the singleton channel dim and transpose so axes match the flow field.
heatmap(states[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
80+
# Overlay a filled black disc at the obstacle position.
# NOTE(review): both coordinates are derived from size(data, 3) — presumably
# the cylinder center; confirm against the WaterLily solver setup.
scatter!(
81+
[size(data, 3)÷2-1], [size(data, 3)÷2-1],
82+
markersize=45, color=:black, legend=false, ticks=false
83+
)
84+
# Label each frame with its index.
annotate!(5, 5, text("i=$i", :left))
85+
end
86+
gif(anim_model, fps=2)
87+
end
88+
89+
# ╔═╡ Cell order:
90+
# ╟─50ce80a3-a1e8-4ba9-a032-dad315bcb432
91+
# ╟─59769504-ebd5-4c6f-981f-d03826d8e34a
92+
# ╠═194baef2-0417-11ec-05ab-4527ef614024
93+
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
94+
# ╟─823b3547-6723-43cf-85e6-cc6eb44efea1
95+
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
96+
# ╟─5531bba6-94bd-4c99-be8c-43fe19ad8a60
97+
# ╟─74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
98+
# ╠═f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
99+
# ╟─55058635-c7e9-4ee3-81c2-0153e84f4c8e
100+
# ╠═fbc287b8-f232-4350-9948-2091908e5a30
101+
# ╟─a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
module SuperResolution
2+
3+
using NeuralOperators
4+
using Flux
5+
using CUDA
6+
using JLD2
7+
8+
include("data.jl")
9+
10+
"""
    update_model!(model_file_path, model)

Move `model` to the CPU and save it to `model_file_path` as a JLD2 file
under the key `model`, overwriting any existing file.
"""
function update_model!(model_file_path, model)
    # Always store the CPU copy so the checkpoint loads on machines without a GPU.
    model = cpu(model)
    jldsave(model_file_path; model)
    # A successful checkpoint is a routine event, not a problem — log at info
    # level rather than the original `@warn`.
    @info "model updated!"
end
15+
16+
"""
    train()

Train a Markov neural operator on one-step flow snapshots and checkpoint the
best model (by validation loss) to `model/model.jld2`.
"""
function train()
17+
# Prefer the GPU when available; forbid scalar GPU indexing so accidental
# CPU↔GPU fallbacks raise instead of silently running slowly.
if has_cuda()
18+
@info "CUDA is on"
19+
device = gpu
20+
CUDA.allowscalar(false)
21+
else
22+
device = cpu
23+
end
24+
25+
# Lift the scalar field to 64 channels, apply four Fourier operator layers
# keeping (24, 24) frequency modes with gelu activations, then project back
# down to one channel.
m = Chain(
26+
Dense(1, 64),
27+
FourierOperator(64=>64, (24, 24), gelu),
28+
FourierOperator(64=>64, (24, 24), gelu),
29+
FourierOperator(64=>64, (24, 24), gelu),
30+
FourierOperator(64=>64, (24, 24), gelu),
31+
Dense(64, 1),
32+
) |> device
33+
34+
# Sum of squared errors, averaged over the batch (last) dimension.
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
35+
36+
# ADAM with L2 weight decay.
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
37+
38+
@info "gen data... "
39+
@time loader_train, loader_test = get_dataloader()
40+
41+
losses = Float32[]
42+
# Compute mean validation loss and checkpoint whenever it reaches a new minimum.
function validate()
43+
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
44+
@info "loss: $validation_loss"
45+
46+
push!(losses, validation_loss)
47+
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
48+
end
49+
# Run validation at most once every 5 seconds of training.
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
50+
51+
# NOTE(review): this materializes the entire training set on the device at
# once — assumes it fits in (GPU) memory; confirm before scaling the dataset.
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52+
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
53+
end
54+
55+
"""
    get_model()

Load the model checkpoint written by [`update_model!`](@ref) from
`model/model.jld2` (relative to this source file) and return it.
"""
function get_model()
    # The do-block form guarantees the file handle is closed even if reading
    # the "model" key throws — the original manual `jldopen`/`close` pair
    # leaked the handle on error.
    return jldopen(joinpath(@__DIR__, "../model/model.jld2")) do f
        f["model"]
    end
end
62+
63+
end

0 commit comments

Comments
 (0)