Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 9695c35

Browse files
authored
Merge pull request #89 from KirillZubov/burger_visualization
update Burger example
2 parents 6ad9828 + 518a0a1 commit 9695c35

File tree

5 files changed

+53
-1
lines changed

5 files changed

+53
-1
lines changed

example/Burgers/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
In this example, a [Burgers' equation](https://en.wikipedia.org/wiki/Burgers%27_equation)
44
is learned by a one-dimensional Fourier neural operator network.
5+
6+
there is learn the operator mapping the initial condition to last point of time evolition of equation in some function space :
7+
8+
```math
9+
u(x, 0) -> u(x, t_end)\
10+
```
11+
12+
![](gallery/burgers.png)
13+
514
Change directory to `example/Burgers` and use following commend to train model:
615

716
```julia
52 KB
Loading

example/Burgers/notebook/mno.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Burgers, Plots
2+
using DataDeps, MAT, MLUtils
3+
using NeuralOperators, Flux
4+
using CUDA, BSON
5+
6+
dataset = Burgers.get_data(n = 1000);
7+
8+
m = Burgers.get_model();
9+
input_data, ground_truth = dataset[1], dataset[2];
10+
11+
12+
i = 1
13+
plot(input_data[1,:,1],ground_truth[1, :, i], label = "ground_truth",title = " Burgers equation u(x,T_end)");
14+
p1 = plot!(input_data[1,:,1],m(view(input_data, :, :, i:i))[1, :, 1],label = "predict");
15+
plot(input_data[1,:,1],ground_truth[1, :, i + 1],label = "ground_truth");
16+
p2 = plot!(input_data[1,:,1],m(view(input_data, :, :, (i + 1):(i + 1)))[1, :, 1],label = "predict");
17+
i = 3
18+
19+
plot(input_data[1,:,1],ground_truth[1, :, i], label = "ground_truth");
20+
p3 = plot!(input_data[1,:,1],m(view(input_data, :, :, i:i))[1, :, 1],label = "predict");
21+
plot(input_data[1,:,1],ground_truth[1, :, i + 1], label = "ground_truth");
22+
p4 = plot!(input_data[1,:,1],m(view(input_data, :, :, (i + 1):(i + 1)))[1, :, 1], label = "predict");
23+
p = plot(p1, p2, p3, p4)
24+
25+

example/Burgers/src/Burgers.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DataDeps, MAT, MLUtils
44
using NeuralOperators, Flux
55
using CUDA, FluxTraining, BSON
66
import Flux: params
7+
using BSON: @save, @load
78

89
include("Burgers_deeponet.jl")
910

@@ -12,7 +13,14 @@ function register_burgers()
1213
"""
1314
Burgers' equation dataset from
1415
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
16+
17+
mapping between initial conditions to the solutions at the last point of time evolition in some function space.
18+
u(x,0) -> u(x, time_end):
19+
20+
* `a`: initial conditions u(x,0)
21+
* `u`: solutions u(x,t_end)
1522
""",
23+
1624
"http://www.med.cgu.edu.tw/NeuralOperators/Burgers_R10.zip",
1725
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
1826
post_fetch_method = unpack))
@@ -63,6 +71,8 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 500)
6371
ToDevice(device, device))
6472

6573
fit!(learner, epochs)
74+
model = learner.model |> cpu
75+
@save "model/model_burger.bson" model
6676

6777
return learner
6878
end
@@ -104,4 +114,12 @@ function train_nomad(; n = 300, cuda = true, learning_rate = 0.001, epochs = 400
104114
return mean_diff
105115
end
106116

117+
118+
function get_model()
119+
model_path = joinpath(@__DIR__, "../model/")
120+
model_file = readdir(model_path)[end]
121+
122+
return BSON.load(joinpath(model_path, model_file), @__MODULE__)[:model]
123+
end
124+
107125
end

example/Burgers/test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Test
88
@test size(xs) == (2, 1024, 1000)
99
@test size(ys) == (1, 1024, 1000)
1010

11-
learner = Burgers.train(epochs = 10)
11+
learner = Burgers.train(epochs = 100)
1212
loss = learner.cbstate.metricsepoch[ValidationPhase()][:Loss].values[end]
1313
@test loss < 0.1
1414

0 commit comments

Comments
 (0)