Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit a959e4f

Browse files
committed
implement super resolution mno
1 parent f4d42f4 commit a959e4f

File tree

6 files changed

+33
-19
lines changed

6 files changed

+33
-19
lines changed

example/FlowOverCircle/src/data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function gen_data(ts::AbstractRange)
2222
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
2323
for (i, t) in enumerate(ts)
2424
sim_step!(circ, t)
25-
𝐩s[:, :, :, i] = Float32.(circ.flow.p)[2:end-1, 2:end-1]
25+
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
2626
end
2727

2828
return 𝐩s

example/SuperResolution/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
# Super Resolution
22

33
The time dependent Navier-Stokes equation is learned by the `MarkovNeuralOperator` with only one time step information.
4-
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/mno.jl.html).
4+
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/super_resolution_mno.jl.html).
5+
6+
Apart from just training a MNO, here, we train the model with lower resolution (96x64) and inference result with higher resolution (192x128).
57

68
| **Ground Truth** | **Inferenced** |
79
|:----------------:|:--------------:|
810
| ![](gallery/ans.gif) | ![](gallery/inferenced.gif) |
911

10-
Change directory to `example/FlowOverCircle` and use following commend to train model:
12+
Change directory to `example/SuperResolution` and use following commend to train model:
1113

1214
```julia
1315
$ julia --proj
1416

15-
julia> using FlowOverCircle; FlowOverCircle.train()
17+
julia> using SuperResolution; SuperResolution.train()
1618
```
-398 KB
Loading
-332 KB
Loading

example/SuperResolution/notebook/mno.jl renamed to example/SuperResolution/notebook/super_resolution_mno.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
### A Pluto.jl notebook ###
2-
# v0.15.1
2+
# v0.16.1
33

44
using Markdown
55
using InteractiveUtils
@@ -8,11 +8,11 @@ using InteractiveUtils
88
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")
99

1010
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
11-
using FlowOverCircle, Plots
11+
using SuperResolution, Plots
1212

1313
# ╔═╡ 50ce80a3-a1e8-4ba9-a032-dad315bcb432
1414
md"
15-
# Markov Neural Operator
15+
# Super Resolution with MNO
1616
1717
JingYu Ning
1818
"
@@ -28,16 +28,26 @@ md"
2828

2929
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
3030
begin
31-
n = 20
32-
data = FlowOverCircle.gen_data(LinRange(100, 100+n-1, n))
31+
n = 10
32+
data = SuperResolution.gen_data(LinRange(100, 100+n-1, n))
3333
end;
3434

35-
# ╔═╡ 9b02b6a2-33c3-4ca6-bfba-0bd74b664830
35+
# ╔═╡ 5531bba6-94bd-4c99-be8c-43fe19ad8a60
36+
md"
37+
## Training
38+
"
39+
40+
# ╔═╡ 74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
41+
md"
42+
Apart from just training a MNO, here, we train the model with lower resolution (96x64) and inference result with higher resolution (192x128).
43+
"
44+
45+
# ╔═╡ f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
3646
begin
3747
anim = @animate for i in 1:size(data)[end]
38-
heatmap(data[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
48+
heatmap(data[1, 1:2:end, 1:2:end, i]', color=:coolwarm, clim=(-1.5, 1.5))
3949
scatter!(
40-
[size(data, 3)÷2], [size(data, 3)÷2-1],
50+
[size(data, 3)÷4-1], [size(data, 3)÷4-1],
4151
markersize=45, color=:black, legend=false, ticks=false
4252
)
4353
annotate!(5, 5, text("i=$i", :left))
@@ -54,7 +64,7 @@ Use the first data generated above as the initial state, and apply the operator
5464

5565
# ╔═╡ fbc287b8-f232-4350-9948-2091908e5a30
5666
begin
57-
m = FlowOverCircle.get_model()
67+
m = SuperResolution.get_model()
5868

5969
states = Array{Float32}(undef, size(data))
6070
states[:, :, :, 1] .= view(data, :, :, :, 1)
@@ -68,7 +78,7 @@ begin
6878
anim_model = @animate for i in 1:size(states)[end]
6979
heatmap(states[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
7080
scatter!(
71-
[size(data, 3)÷2], [size(data, 3)÷2-1],
81+
[size(data, 3)÷2-1], [size(data, 3)÷2-1],
7282
markersize=45, color=:black, legend=false, ticks=false
7383
)
7484
annotate!(5, 5, text("i=$i", :left))
@@ -79,11 +89,13 @@ end
7989
# ╔═╡ Cell order:
8090
# ╟─50ce80a3-a1e8-4ba9-a032-dad315bcb432
8191
# ╟─59769504-ebd5-4c6f-981f-d03826d8e34a
82-
# ╟─194baef2-0417-11ec-05ab-4527ef614024
92+
# ╠═194baef2-0417-11ec-05ab-4527ef614024
8393
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
8494
# ╟─823b3547-6723-43cf-85e6-cc6eb44efea1
8595
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
86-
# ╟─9b02b6a2-33c3-4ca6-bfba-0bd74b664830
96+
# ╟─5531bba6-94bd-4c99-be8c-43fe19ad8a60
97+
# ╟─74fc528f-ccd4-4670-9b17-dbfa7a1c74b6
98+
# ╠═f6d1ce85-a195-4ab1-bd3a-dbd4b0d1fcca
8799
# ╟─55058635-c7e9-4ee3-81c2-0153e84f4c8e
88100
# ╠═fbc287b8-f232-4350-9948-2091908e5a30
89101
# ╟─a0b5e94c-a839-4cc0-a325-1a4ac39fafbc

example/SuperResolution/src/data.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ function circle(n, m; Re=250)
1616
end
1717

1818
function gen_data(ts::AbstractRange)
19-
n, m = 3(2^5), 2^6
19+
n, m = 2 * 3(2^5), 2 * 2^6
2020
circ = circle(n, m)
2121

2222
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
2323
for (i, t) in enumerate(ts)
2424
sim_step!(circ, t)
25-
𝐩s[:, :, :, i] = Float32.(circ.flow.p)[2:end-1, 2:end-1]
25+
𝐩s[1, :, :, i] .= Float32.(circ.flow.p)[2:end-1, 2:end-1]
2626
end
2727

2828
return 𝐩s
@@ -33,7 +33,7 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
3333

3434
n_train, n_test = floor(Int, length(ts)*ratio), floor(Int, length(ts)*(1-ratio))
3535

36-
𝐱_train, 𝐲_train = data[:, :, :, 1:(n_train-1)], data[:, :, :, 2:n_train]
36+
𝐱_train, 𝐲_train = data[:, 1:2:end, 1:2:end, 1:(n_train-1)], data[:, 1:2:end, 1:2:end, 2:n_train]
3737
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
3838

3939
𝐱_test, 𝐲_test = data[:, :, :, (end-n_test+1):(end-1)], data[:, :, :, (end-n_test+2):end]

0 commit comments

Comments
 (0)