Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f502087

Browse files
foldfelisyuehhua
andauthored
Double Pendulum example (#17)
* add DoublePendulumChaotic dataset * build proj for double pendulum * get data from file * data visualization * add annotation * get_dataloader * revise get_dataloader * normalize data * implement MNO * revise model * Update example/DoublePendulum/test/data.jl Co-authored-by: Yueh-Hua Tu <[email protected]> * Update README.md Co-authored-by: Yueh-Hua Tu <[email protected]> * Update README.md Co-authored-by: Yueh-Hua Tu <[email protected]> * revise model * revise model and save model * fallback to last model * train MNO in 2-D and make sure loss<1e-2 about 5e-3 at epoch=35 * implement get_model and refactor * result visualization (failed) * revise data * add gradient * implement MNO * fix test * BPTT not work QQ * navier stokes equation * implement MNO on flow over a circle * bug fix * implement demo * add MNO to model * update readme Co-authored-by: Yueh-Hua Tu <[email protected]>
1 parent 9fd713b commit f502087

File tree

15 files changed

+275
-15
lines changed

15 files changed

+275
-15
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ docs/site/
2222
# committed for packages, but should be committed for applications that require a static
2323
# environment.
2424
Manifest.toml
25+
26+
*.jld2

README.md

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ between two continuous function spaces. The kernel can be trained on different g
2020

2121
Fourier neural operator learns a neural operator with Dirichlet kernel to form a Fourier transformation. It performs Fourier transformation across infinite-dimensional function spaces and learns better than neural operator.
2222

23-
Currently, `FourierOperator` is provided in this work.
23+
Currently, the `FourierOperator` layer is provided in this work.
24+
As for model, there are `FourierNeuralOperator` and `MarkovNeuralOperator` provided. Please take a glance at them [here](src/model.jl).
2425

2526
## Usage
2627

@@ -68,25 +69,34 @@ PDE training examples are provided in `example` folder.
6869
Use following commend to train model:
6970

7071
```julia
71-
$ julia --proj=example/Burgers
72+
$ julia --proj
7273

7374
julia> using Burgers; Burgers.train()
7475
```
7576

76-
### Two-dimensional Darcy flow equation
77+
### Two-dimensional with time Navier-Stokes equation
7778

78-
WIP
79+
The Navier-Stokes equation is learned by the `MarkovNeuralOperator` with only one time step information. Example can be found in `example/FlowOverCircle`.
7980

80-
### Two-dimensional Navier-Stokes equation
81+
| **Ground Truth** | **Inferenced** |
82+
|:----------------:|:--------------:|
83+
| ![](example/FlowOverCircle/gallery/ans.gif) | ![](example/FlowOverCircle/gallery/inferenced.gif) |
8184

82-
WIP
85+
Use following commend to train model:
86+
87+
```julia
88+
$ julia --proj
89+
90+
julia> using FlowOverCircle; FlowOverCircle.train()
91+
```
8392

8493
## Roadmap
8594

8695
- [x] `FourierOperator` layer
8796
- [x] One-dimensional Burgers' equation example
88-
- [ ] Two-dimensional Darcy flow equation example
89-
- [ ] Two-dimensional Navier-Stokes equation example
97+
- [x] Two-dimensional with time Navier-Stokes equations example
98+
- [x] `MarkovNeuralOperator` model
99+
- [x] Flow over a circle prediction example
90100
- [ ] `NeuralOperator` layer
91101
- [ ] Poisson equation example
92102

example/Burgers/src/data.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ using DataDeps
22
using Fetch
33
using MAT
44

5-
export get_burgers_data
6-
75
function register_burgers()
86
register(DataDep(
97
"Burgers",
@@ -18,7 +16,7 @@ function register_burgers()
1816
))
1917
end
2018

21-
function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
19+
function get_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
2220
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
2321
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
2422
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
@@ -32,7 +30,7 @@ function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples
3230
end
3331

3432
function get_dataloader(; n_train=1800, n_test=200, batchsize=100)
35-
𝐱, 𝐲 = get_burgers_data(n=2048)
33+
𝐱, 𝐲 = get_data(n=2048)
3634

3735
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
3836
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)

example/Burgers/test/data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "get burgers data" begin
2-
xs, ys = get_burgers_data(n=1000)
2+
xs, ys = Burgers.get_data(n=1000)
33

44
@test size(xs) == (2, 1024, 1000)
55
@test size(ys) == (1024, 1000)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "FlowOverCircle"
2+
uuid = "1fc04e5d-1dd1-42ff-8d75-1d53504b2476"
3+
4+
[deps]
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
7+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
10+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
11+
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
WaterLily = "ed894a53-35f9-47f1-b17f-85db9237eebd"
14+
15+
[extras]
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["Test"]
806 KB
Loading
846 KB
Loading

example/FlowOverCircle/model/.gitkeep

Whitespace-only changes.
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
### A Pluto.jl notebook ###
2+
# v0.15.1
3+
4+
using Markdown
5+
using InteractiveUtils
6+
7+
# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024
8+
using Pkg; Pkg.develop(path=".."); Pkg.activate("..")
9+
10+
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b
11+
using FlowOverCircle, Plots
12+
13+
# ╔═╡ 50ce80a3-a1e8-4ba9-a032-dad315bcb432
14+
md"
15+
# Markov Neural Operator
16+
17+
JingYu Ning
18+
"
19+
20+
# ╔═╡ 59769504-ebd5-4c6f-981f-d03826d8e34a
21+
md"
22+
This demo trains a Markov neural operator (MNO) introduced by [Zongyi Li *et al.*](https://arxiv.org/abs/2106.06898) with only one time step information. Then composed the operator to a Markov chain and inference the Navier-Stokes equations."
23+
24+
# ╔═╡ 823b3547-6723-43cf-85e6-cc6eb44efea1
25+
md"
26+
## Generate data
27+
"
28+
29+
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
30+
begin
31+
n = 20
32+
data = FlowOverCircle.gen_data(LinRange(100, 100+n-1, n))
33+
end;
34+
35+
# ╔═╡ 9b02b6a2-33c3-4ca6-bfba-0bd74b664830
36+
begin
37+
anim = @animate for i in 1:size(data)[end]
38+
heatmap(data[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
39+
scatter!(
40+
[size(data, 3)÷2], [size(data, 3)÷2-1],
41+
markersize=45, color=:black, legend=false, ticks=false
42+
)
43+
annotate!(5, 5, text("i=$i", :left))
44+
end
45+
gif(anim, fps=2)
46+
end
47+
48+
# ╔═╡ 55058635-c7e9-4ee3-81c2-0153e84f4c8e
49+
md"
50+
## Inference
51+
52+
Use the first data generated above as the initial state, and apply the operator recurrently.
53+
"
54+
55+
# ╔═╡ fbc287b8-f232-4350-9948-2091908e5a30
56+
begin
57+
m = FlowOverCircle.get_model()
58+
59+
states = Array{Float32}(undef, size(data))
60+
states[:, :, :, 1] .= view(data, :, :, :, 1)
61+
for i in 2:size(data)[end]
62+
states[:, :, :, i:i] .= m(view(states, :, :, :, i-1:i-1))
63+
end
64+
end
65+
66+
# ╔═╡ a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
67+
begin
68+
anim_model = @animate for i in 1:size(states)[end]
69+
heatmap(states[1, :, :, i]', color=:coolwarm, clim=(-1.5, 1.5))
70+
scatter!(
71+
[size(data, 3)÷2], [size(data, 3)÷2-1],
72+
markersize=45, color=:black, legend=false, ticks=false
73+
)
74+
annotate!(5, 5, text("i=$i", :left))
75+
end
76+
gif(anim_model, fps=2)
77+
end
78+
79+
# ╔═╡ Cell order:
80+
# ╟─50ce80a3-a1e8-4ba9-a032-dad315bcb432
81+
# ╟─59769504-ebd5-4c6f-981f-d03826d8e34a
82+
# ╟─194baef2-0417-11ec-05ab-4527ef614024
83+
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b
84+
# ╟─823b3547-6723-43cf-85e6-cc6eb44efea1
85+
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7
86+
# ╟─9b02b6a2-33c3-4ca6-bfba-0bd74b664830
87+
# ╟─55058635-c7e9-4ee3-81c2-0153e84f4c8e
88+
# ╠═fbc287b8-f232-4350-9948-2091908e5a30
89+
# ╟─a0b5e94c-a839-4cc0-a325-1a4ac39fafbc
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
module FlowOverCircle
2+
3+
using NeuralOperators
4+
using Flux
5+
using CUDA
6+
using JLD2
7+
8+
include("data.jl")
9+
10+
function update_model!(model_file_path, model)
11+
model = cpu(model)
12+
jldsave(model_file_path; model)
13+
@warn "model updated!"
14+
end
15+
16+
function train()
17+
if has_cuda()
18+
@info "CUDA is on"
19+
device = gpu
20+
CUDA.allowscalar(false)
21+
else
22+
device = cpu
23+
end
24+
25+
m = Chain(
26+
Dense(1, 64),
27+
FourierOperator(64=>64, (24, 24), gelu),
28+
FourierOperator(64=>64, (24, 24), gelu),
29+
FourierOperator(64=>64, (24, 24), gelu),
30+
FourierOperator(64=>64, (24, 24), gelu),
31+
Dense(64, 1),
32+
) |> device
33+
34+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
35+
36+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
37+
38+
@info "gen data... "
39+
@time loader_train, loader_test = get_dataloader()
40+
41+
losses = Float32[]
42+
function validate()
43+
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test)
44+
@info "loss: $validation_loss"
45+
46+
push!(losses, validation_loss)
47+
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m)
48+
end
49+
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
50+
51+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
52+
Flux.@epochs 50 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
53+
end
54+
55+
function get_model()
56+
f = jldopen(joinpath(@__DIR__, "../model/model.jld2"))
57+
model = f["model"]
58+
close(f)
59+
60+
return model
61+
end
62+
63+
end

0 commit comments

Comments
 (0)