Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 6a390d9

Browse files
authored
Merge pull request #6 from foldfelis/example
Example for Burger's equation
2 parents c797169 + a4ce7b4 commit 6a390d9

File tree

5 files changed

+50
-7
lines changed

5 files changed

+50
-7
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1313
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1414

1515
[compat]
16+
DataDeps = "0.7"
17+
FFTW = "1.4"
18+
Fetch = "0.1"
19+
Flux = "0.12"
20+
MAT = "0.10"
21+
Tullio = "0.3"
22+
Zygote = "0.6"
1623
julia = "1.6"
1724

1825
[extras]

example/burgers.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using NeuralOperators
2+
using Flux
3+
# using CUDA
4+
5+
# if has_cuda()
6+
# @info "CUDA is on"
7+
# device = gpu
8+
# CUDA.allowscalar(false)
9+
# else
10+
device = cpu
11+
# end
12+
13+
m = FourierNeuralOperator() |> device
14+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
15+
16+
𝐱, 𝐲 = get_burgers_data(n=2048)
17+
18+
n_train = 2000
19+
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
20+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=20, shuffle=true)
21+
22+
n_test = 40
23+
𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
24+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=20, shuffle=false)
25+
26+
function loss_test()
27+
l = 0
28+
for (𝐱, 𝐲) in loader_test
29+
l += loss(𝐱, 𝐲)
30+
end
31+
@info "loss: $(l/length(loader_test))"
32+
end
33+
34+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
35+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
36+
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10)))

src/data.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ function register_datasets()
1515
))
1616
end
1717

18-
function get_burgers_data(; n=1000, Δsamples=2^3, grid_size=div(2^13, Δsamples))
18+
function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
1919
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
20-
x_data = collect(read(file, "a")[1:n, 1:Δsamples:end]')
21-
y_data = collect(read(file, "u")[1:n, 1:Δsamples:end]')
20+
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
21+
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
2222
close(file)
2323

24-
x_loc_data = Array{Float32, 3}(undef, 2, grid_size, n)
24+
x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
2525
x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
2626
x_loc_data[2, :, :] .= x_data
2727

test/data.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "get burgers data" begin
2-
xs, ys = get_burgers_data()
2+
xs, ys = get_burgers_data(n=1000)
33

44
@test size(xs) == (2, 1024, 1000)
55
@test size(ys) == (1024, 1000)

test/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SpectralConv1d(ch, modes)
88
)
99

10-
𝐱, _ = get_burgers_data()
10+
𝐱, _ = get_burgers_data(n=1000)
1111
@test size(m(𝐱)) == (64, 1024, 1000)
1212

1313
T = Float32
@@ -25,7 +25,7 @@ end
2525
FourierOperator(ch, modes)
2626
)
2727

28-
𝐱, _ = get_burgers_data()
28+
𝐱, _ = get_burgers_data(n=1000)
2929
@test size(m(𝐱)) == (64, 1024, 1000)
3030

3131
loss(x, y) = Flux.mse(m(x), y)

0 commit comments

Comments
 (0)