Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f4d42f4

Browse files
committed
copy example from flow over circle
1 parent 05873a1 commit f4d42f4

File tree

10 files changed

+239
-0
lines changed

10 files changed

+239
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "SuperResolution"
2+
uuid = "a8258e1f-331c-4af2-83e9-878628278453"
3+
4+
[deps]
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
10+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
11+
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"
14+
15+
[extras]
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["Test"]

example/SuperResolution/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Super Resolution
2+
3+
The time dependent Navier-Stokes equation is learned by the `MarkovNeuralOperator` with only one time step information.
4+
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/mno.jl.html).
5+
6+
| **Ground Truth** | **Inferenced** |
7+
|:----------------:|:--------------:|
8+
| ![](gallery/ans.gif) | ![](gallery/inferenced.gif) |
9+
10+
Change directory to `example/FlowOverCircle` and use following commend to train model:
11+
12+
```julia
13+
$ julia --proj
14+
15+
julia> using FlowOverCircle; FlowOverCircle.train()
16+
```
806 KB
Loading
846 KB
Loading

example/SuperResolution/model/.gitkeep

Whitespace-only changes.
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
### A Pluto.jl notebook ###
2+
# v0.15.1
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024
8+
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")
9+
10+
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
11+
using FlowOverCircle, Plots
12+
13+
# ╔═╡ 50ce80a3-a1e8-4ba9-a032-dad315bcb432
14+
md"
15+
# Markov Neural Operator
16+
17+
JingYu Ning
18+
"
19+
20+
# ╔═╡ 59769504-ebd5-4c6f-981f-d03826d8e34a
21+
md"
22+
This demo trains a Markov neural operator (MNO) introduced by [Zongyi Li *et al.*](https://arxiv.org/abs/2106.06898) with only one time step information. Then composed the operator to a Markov chain and inference the Navier-Stokes equations."
23+
24+
# ╔═╡ 823b3547-6723-43cf-85e6-cc6eb44efea1
25+
md"
26+
## Generate data
27+
"
28+
29+
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
30+
begin
31+
n = 20
32+
data = FlowOverCircle.gen_data(LinRange(100, 100+n-1, n))
33+
end;
34+
35+
# ╔═╡ 9b02b6a2-33c3-4ca6-bfba-0bd74b664830
36+
begin
37+
anim = @animate for i in 1:size(data)[end]
38+
heatmap(data[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
39+
scatter!(
40+
[size(data, 3)÷2], [size(data, 3)÷2-1],
41+
markersize=45, color=:black, legend=false, ticks=false
42+
)
43+
annotate!(5, 5, text("i=$i", :left))
44+
end
45+
gif(anim, fps=2)
46+
end
47+
48+
# ╔═╡ 55058635-c7e9-4ee3-81c2-0153e84f4c8e
49+
md"
50+
## Inference
51+
52+
Use the first data generated above as the initial state, and apply the operator recurrently.
53+
"
54+
55+
# ╔═╡ fbc287b8-f232-4350-9948-2091908e5a30
56+
begin
57+
m = FlowOverCircle.get_model()
58+
59+
states = Array{Float32}(undef, size(data))
60+
states[:, :, :, 1] .= view(data, :, :, :, 1)
61+
for i in 2:size(data)[end]
62+
states[:, :, :, i:i] .= m(view(states, :, :, :, i-1:i-1))
63+
end
64+
end
65+
66+
# ╔═╡ a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
67+
begin
68+
anim_model = @animate for i in 1:size(states)[end]
69+
heatmap(states[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
70+
scatter!(
71+
[size(data, 3)÷2], [size(data, 3)÷2-1],
72+
markersize=45, color=:black, legend=false, ticks=false
73+
)
74+
annotate!(5, 5, text("i=$i", :left))
75+
end
76+
gif(anim_model, fps=2)
77+
end
78+
79+
# ╔═╡ Cell order:
80+
# ╟─50ce80a3-a1e8-4ba9-a032-dad315bcb432
81+
# ╟─59769504-ebd5-4c6f-981f-d03826d8e34a
82+
# ╟─194baef2-0417-11ec-05ab-4527ef614024
83+
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
84+
# ╟─823b3547-6723-43cf-85e6-cc6eb44efea1
85+
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
86+
# ╟─9b02b6a2-33c3-4ca6-bfba-0bd74b664830
87+
# ╟─55058635-c7e9-4ee3-81c2-0153e84f4c8e
88+
# ╠═fbc287b8-f232-4350-9948-2091908e5a30
89+
# ╟─a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
module SuperResolution
2+
3+
using NeuralOperators
4+
using Flux
5+
using CUDA
6+
using JLD2
7+
8+
include("data.jl")
9+
10+
function update_model!(model_file_path, model)
11+
model = cpu(model)
12+
jldsave(model_file_path; model)
13+
@warn "model updated!"
14+
end
15+
16+
function train()
17+
if has_cuda()
18+
@info "CUDA is on"
19+
device = gpu
20+
CUDA.allowscalar(false)
21+
else
22+
device = cpu
23+
end
24+
25+
m = Chain(
26+
Dense(1, 64),
27+
FourierOperator(64=>64, (24, 24), gelu),
28+
FourierOperator(64=>64, (24, 24), gelu),
29+
FourierOperator(64=>64, (24, 24), gelu),
30+
FourierOperator(64=>64, (24, 24), gelu),
31+
Dense(64, 1),
32+
) |> device
33+
34+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
35+
36+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
37+
38+
@info "gen data... "
39+
@time loader_train, loader_test = get_dataloader()
40+
41+
losses = Float32[]
42+
function validate()
43+
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
44+
@info "loss: $validation_loss"
45+
46+
push!(losses, validation_loss)
47+
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
48+
end
49+
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
50+
51+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52+
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
53+
end
54+
55+
function get_model()
56+
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
57+
model = f["model"]
58+
close(f)
59+
60+
return model
61+
end
62+
63+
end
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using WaterLily
2+
using LinearAlgebra: norm2
3+
4+
"""
5+
circle(n, m; Re=250)
6+
7+
This function is copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
8+
"""
9+
function circle(n, m; Re=250)
10+
# Set physical parameters
11+
U, R, center = 1., m/8., [m/2, m/2]
12+
ν = U * R / Re
13+
14+
body = AutoBody((x,t) -> norm2(x .- center) - R)
15+
Simulation((n+2, m+2), [U, 0.], R; ν, body)
16+
end
17+
18+
function gen_data(ts::AbstractRange)
19+
n, m = 3(2^5), 2^6
20+
circ = circle(n, m)
21+
22+
𝐩s = Array{Float32}(undef, 1, n, m, length(ts))
23+
for (i, t) in enumerate(ts)
24+
sim_step!(circ, t)
25+
𝐩s[:, :, :, i] = Float32.(circ.flow.p)[2:end-1, 2:end-1]
26+
end
27+
28+
return 𝐩s
29+
end
30+
31+
function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
32+
data = gen_data(ts)
33+
34+
n_train, n_test = floor(Int, length(ts)*ratio), floor(Int, length(ts)*(1-ratio))
35+
36+
𝐱_train, 𝐲_train = data[:, :, :, 1:(n_train-1)], data[:, :, :, 2:n_train]
37+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
38+
39+
𝐱_test, 𝐲_test = data[:, :, :, (end-n_test+1):(end-1)], data[:, :, :, (end-n_test+2):end]
40+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
41+
42+
return loader_train, loader_test
43+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
@testset "flow over circle" begin
2+
3+
end
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using SuperResolution
2+
using Test
3+
4+
@testset "SuperResolution" begin
5+
include("data.jl")
6+
end

0 commit comments

Comments
 (0)