Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit fba4619

Browse files
authored
Merge pull request #2 from foldfelis/fno
Implement 1-D Fourier Neural Operator
2 parents 673ca1e + 6027fdd commit fba4619

File tree

8 files changed

+186
-3
lines changed

8 files changed

+186
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
/Manifest.toml
2+
data

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
33
authors = ["JingYu Ning <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
8+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
9+
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
10+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
11+
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
12+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
13+
614
[compat]
715
julia = "1.6"
816

src/NeuralOperators.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module NeuralOperators
2+
function __init__()
3+
register_datasets()
4+
end
25

3-
# Write your package code here.
4-
6+
include("preprocess.jl")
7+
include("fourier.jl")
58
end

src/fourier.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using Flux
2+
using FFTW
3+
using Tullio
4+
5+
export
6+
SpectralConv1d,
7+
FourierOperator,
8+
FNO
9+
10+
c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
11+
12+
struct SpectralConv1d{T, S}
13+
weight::T
14+
in_channel::S
15+
out_channel::S
16+
modes::S
17+
σ
18+
end
19+
20+
function SpectralConv1d(
21+
ch::Pair{<:Integer,<:Integer},
22+
modes::Integer,
23+
σ=identity;
24+
init=c_glorot_uniform,
25+
T::DataType=ComplexF32
26+
)
27+
in_chs, out_chs = ch
28+
scale = one(T) / (in_chs * out_chs)
29+
weights = scale * init(out_chs, in_chs, modes)
30+
31+
return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
32+
end
33+
34+
Flux.@functor SpectralConv1d
35+
36+
function (m::SpectralConv1d)(𝐱::AbstractArray)
37+
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
38+
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
39+
40+
# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
41+
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
42+
43+
s = size(𝐱_weighted)
44+
d = size(𝐱, 2) - m.modes
45+
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
46+
47+
𝐱_out = ifft(𝐱_padded, 2)
48+
49+
return m.σ.(𝐱_out)
50+
end
51+
52+
function FourierOperator(
53+
ch::Pair{<:Integer,<:Integer},
54+
modes::Integer,
55+
σ=identity
56+
)
57+
return Chain(
58+
Parallel(+,
59+
Dense(ch.first, ch.second, init=c_glorot_uniform),
60+
SpectralConv1d(ch, modes)
61+
),
62+
x -> σ.(x)
63+
)
64+
end
65+
66+
function FNO()
67+
modes = 16
68+
ch = 64 => 64
69+
σ = x -> @. log(1 + exp(x))
70+
71+
return Chain(
72+
Dense(2, 64, init=c_glorot_uniform),
73+
FourierOperator(ch, modes, σ),
74+
FourierOperator(ch, modes, σ),
75+
FourierOperator(ch, modes, σ),
76+
FourierOperator(ch, modes),
77+
Dense(64, 128, σ, init=c_glorot_uniform),
78+
Dense(128, 1, init=c_glorot_uniform),
79+
flatten
80+
)
81+
end

src/preprocess.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using DataDeps
2+
using Fetch
3+
using MAT
4+
5+
export
6+
get_data
7+
8+
function register_datasets()
9+
register(DataDep(
10+
"BurgersR10",
11+
"""
12+
Burgers' equation dataset from
13+
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
14+
""",
15+
"https://drive.google.com/file/d/16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe/view?usp=sharing",
16+
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
17+
fetch_method=gdownload,
18+
post_fetch_method=unpack
19+
))
20+
end
21+
22+
function get_data(; n=1000, Δsamples=2^3, grid_size=div(2^13, Δsamples))
23+
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
24+
x_data = collect(read(file, "a")[1:n, 1:Δsamples:end]')
25+
y_data = collect(read(file, "u")[1:n, 1:Δsamples:end]')
26+
close(file)
27+
28+
x_loc_data = Array{Float32, 3}(undef, 2, grid_size, n)
29+
x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
30+
x_loc_data[2, :, :] .= x_data
31+
32+
return x_loc_data, y_data
33+
end

test/fourier.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using Flux
2+
3+
@testset "SpectralConv1d" begin
4+
modes = 16
5+
ch = 64 => 64
6+
7+
m = Chain(
8+
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
9+
SpectralConv1d(ch, modes)
10+
)
11+
12+
𝐱, _ = get_data()
13+
@test size(m(𝐱)) == (64, 1024, 1000)
14+
15+
T = Float32
16+
loss(x, y) = Flux.mse(real.(m(x)), y)
17+
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
18+
Flux.train!(loss, params(m), data, Flux.ADAM())
19+
end
20+
21+
@testset "FourierOperator" begin
22+
modes = 16
23+
ch = 64 => 64
24+
25+
m = Chain(
26+
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
27+
FourierOperator(ch, modes)
28+
)
29+
30+
𝐱, _ = get_data()
31+
@test size(m(𝐱)) == (64, 1024, 1000)
32+
33+
T = Float32
34+
loss(x, y) = Flux.mse(real.(m(x)), y)
35+
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
36+
Flux.train!(loss, params(m), data, Flux.ADAM())
37+
end
38+
39+
@testset "FNO" begin
40+
𝐱, 𝐲 = get_data()
41+
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
42+
@test size(FNO()(𝐱)) == size(𝐲)
43+
44+
m = FNO()
45+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
46+
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
47+
Flux.train!(loss, params(m), data, Flux.ADAM())
48+
end

test/preprocess.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@testset "get data" begin
2+
xs, ys = get_data()
3+
4+
@test size(xs) == (2, 1024, 1000)
5+
@test size(ys) == (1024, 1000)
6+
end

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using NeuralOperators
22
using Test
33

4+
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
5+
46
@testset "NeuralOperators.jl" begin
5-
# Write your tests here.
7+
include("preprocess.jl")
8+
include("fourier.jl")
69
end

0 commit comments

Comments
 (0)