Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit cf8eddd

Browse files
authored
Merge pull request #16 from foldfelis/n-d_example
Fix functor bug and build project for Burgers' equation
2 parents fa998db + 29d55da commit cf8eddd

File tree

13 files changed

+144
-171
lines changed

13 files changed

+144
-171
lines changed

example/Burgers/Project.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name = "Burgers"
2+
uuid = "5b053d85-f964-4905-ae31-99551cd8d3ad"
3+
4+
[deps]
5+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6+
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
7+
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
8+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
9+
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
10+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
11+
12+
[extras]
13+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
15+
[targets]
16+
test = ["Test"]

example/Burgers/src/Burgers.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
module Burgers
2+
3+
using NeuralOperators
4+
using Flux
5+
using CUDA
6+
7+
include("data.jl")
8+
9+
__init__() = register_burgers()
10+
11+
function train()
12+
if has_cuda()
13+
@info "CUDA is on"
14+
device = gpu
15+
CUDA.allowscalar(false)
16+
else
17+
device = cpu
18+
end
19+
20+
modes = (16, )
21+
ch = 64 => 64
22+
σ = gelu
23+
m = Chain(
24+
Dense(2, 64),
25+
FourierOperator(ch, modes, σ),
26+
FourierOperator(ch, modes, σ),
27+
FourierOperator(ch, modes, σ),
28+
FourierOperator(ch, modes),
29+
Dense(64, 128, σ),
30+
Dense(128, 1),
31+
flatten
32+
) |> device
33+
34+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
35+
36+
loader_train, loader_test = get_dataloader()
37+
38+
function validate()
39+
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
40+
@info "loss: $(sum(validation_losses)/length(loader_test))"
41+
end
42+
43+
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
44+
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
45+
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
46+
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
47+
end
48+
49+
end

example/Burgers/src/data.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using DataDeps
2+
using Fetch
3+
using MAT
4+
5+
export get_burgers_data
6+
7+
function register_burgers()
8+
register(DataDep(
9+
"Burgers",
10+
"""
11+
Burgers' equation dataset from
12+
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
13+
""",
14+
"https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing",
15+
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
16+
fetch_method=gdownload,
17+
post_fetch_method=unpack
18+
))
19+
end
20+
21+
function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
22+
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
23+
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
24+
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
25+
close(file)
26+
27+
x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
28+
x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
29+
x_loc_data[2, :, :] .= x_data
30+
31+
return x_loc_data, y_data
32+
end
33+
34+
function get_dataloader(; n_train=1800, n_test=200, batchsize=100)
35+
𝐱, 𝐲 = get_burgers_data(n=2048)
36+
37+
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
38+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
39+
40+
𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
41+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
42+
43+
return loader_train, loader_test
44+
end

example/Burgers/test/data.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@testset "get burgers data" begin
2+
xs, ys = get_burgers_data(n=1000)
3+
4+
@test size(xs) == (2, 1024, 1000)
5+
@test size(ys) == (1024, 1000)
6+
end

example/Burgers/test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using Burgers
2+
using Test
3+
4+
@testset "Burgers" begin
5+
include("data.jl")
6+
end

example/burgers.jl

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/NeuralOperators.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
module NeuralOperators
2-
using DataDeps
3-
using Fetch
4-
using MAT
5-
using StatsBase
6-
72
using Flux
83
using FFTW
94
using Tullio
@@ -13,11 +8,6 @@ module NeuralOperators
138
using Zygote
149
using ChainRulesCore
1510

16-
function __init__()
17-
register_datasets()
18-
end
19-
20-
include("data.jl")
2111
include("fourier.jl")
2212
include("model.jl")
2313
end

src/data.jl

Lines changed: 0 additions & 85 deletions
This file was deleted.

src/fourier.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,25 @@ export
33
FourierOperator
44

55
struct SpectralConv{P, N, T, S, F}
6+
permuted::Bool
67
weight::T
78
in_channel::S
89
out_channel::S
910
modes::NTuple{N, S}
1011
σ::F
1112
end
1213

14+
function SpectralConv(
15+
permuted::Bool,
16+
weight::T,
17+
in_channel::S,
18+
out_channel::S,
19+
modes::NTuple{N, S},
20+
σ::F
21+
) where {N, T, S, F}
22+
return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ)
23+
end
24+
1325
"""
1426
SpectralConv(
1527
ch, modes, σ=identity;
@@ -50,18 +62,14 @@ function SpectralConv(
5062
in_chs, out_chs = ch
5163
scale = one(T) / (in_chs * out_chs)
5264
weights = scale * init(out_chs, in_chs, prod(modes))
53-
W = typeof(weights)
54-
F = typeof(σ)
5565

56-
return SpectralConv{permuted,N,W,S,F}(weights, in_chs, out_chs, modes, σ)
66+
return SpectralConv(permuted, weights, in_chs, out_chs, modes, σ)
5767
end
5868

5969
Flux.@functor SpectralConv
6070

6171
Base.ndims(::SpectralConv{P,N}) where {P,N} = N
6272

63-
permuted(::SpectralConv{P}) where {P} = P
64-
6573
function Base.show(io::IO, l::SpectralConv{P}) where {P}
6674
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
6775
end
@@ -142,7 +150,7 @@ end
142150
Flux.@functor FourierOperator
143151

144152
function Base.show(io::IO, l::FourierOperator)
145-
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(permuted(l.conv)))")
153+
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(l.conv.permuted))")
146154
end
147155

148156
function (m::FourierOperator)(𝐱)

test/data.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)