Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -23,6 +24,7 @@ FFTW = "1.4"
Flux = "0.12"
KernelAbstractions = "0.7"
MAT = "0.10"
StatsBase = "0.33"
Tullio = "0.3"
Zygote = "0.6"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module NeuralOperators
using DataDeps
using Fetch
using MAT
using StatsBase

using Flux
using FFTW
Expand Down
64 changes: 60 additions & 4 deletions src/data.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
export
get_burgers_data
UnitGaussianNormalizer,
encode,
decode,
get_burgers_data,
get_darcy_flow_data

function register_datasets()
struct UnitGaussianNormalizer{T}
mean::Array{T}
std::Array{T}
ϵ::T
end

function UnitGaussianNormalizer(𝐱; ϵ=1f-5)
dims = 1:length(size(𝐱))-1

return UnitGaussianNormalizer(mean(𝐱, dims=dims), StatsBase.std(𝐱, dims=dims), ϵ)
end

encode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. (𝐱-n.mean) / (n.std+n.ϵ)
decode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. 𝐱 * (n.std+n.ϵ) + n.mean


function register_burgers()
register(DataDep(
"BurgersR10",
"Burgers",
"""
Burgers' equation dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
Expand All @@ -15,8 +35,27 @@ function register_datasets()
))
end

function register_darcy_flow()
register(DataDep(
"DarcyFlow",
"""
Darcy flow dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
""",
"https://drive.google.com/file/d/1Z1uxG9R8AdAGJprG5STcphysjm56_0Jf/view?usp=sharing",
"802825de9da7398407296c99ca9ceb2371c752f6a3bdd1801172e02ce19edda4",
fetch_method=gdownload,
post_fetch_method=unpack
))
end

function register_datasets()
register_burgers()
register_darcy_flow()
end

function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
close(file)
Expand All @@ -27,3 +66,20 @@ function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples

return x_loc_data, y_data
end

function get_darcy_flow_data(; n=1024, Δsamples=5, T=Float32, test_data=true)
# size(training_data) == size(testing_data) == (1024, 421, 421)
file = test_data ? "piececonst_r421_N1024_smooth2.mat" : "piececonst_r421_N1024_smooth1.mat"
file = matopen(joinpath(datadep"DarcyFlow", file))
x_data = T.(permutedims(read(file, "coeff")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
y_data = T.(permutedims(read(file, "sol")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
close(file)

x_dims = pushfirst!([size(x_data)...], 1)
y_dims = pushfirst!([size(y_data)...], 1)
x_data, y_data = reshape(x_data, x_dims...), reshape(y_data, y_dims...)

x_normalizer, y_normalizer = UnitGaussianNormalizer(x_data), UnitGaussianNormalizer(y_data)

return encode(x_normalizer, x_data), encode(y_normalizer, y_data), x_normalizer, y_normalizer
end
46 changes: 27 additions & 19 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,61 @@
export
SpectralConv1d,
SpectralConv,
FourierOperator

struct SpectralConv1d{T, S}
struct SpectralConv{N, T, S}
weight::T
in_channel::S
out_channel::S
modes::S
modes::NTuple{N, S}
ndim::S
σ
end

c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...)*im

function SpectralConv1d(
ch::Pair{<:Integer, <:Integer},
modes::Integer,
function SpectralConv(
ch::Pair{S, S},
modes::NTuple{N, S},
σ=identity;
init=c_glorot_uniform,
T::DataType=ComplexF32
)
) where {S<:Integer, N}
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(out_chs, in_chs, modes)
weights = scale * init(out_chs, in_chs, prod(modes))

return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
return SpectralConv(weights, in_chs, out_chs, modes, N, σ)
end

Flux.@functor SpectralConv1d
Flux.@functor SpectralConv

Base.ndims(::SpectralConv{N}) where {N} = N

spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2, 1, 3)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
function (m::SpectralConv)(𝐱::AbstractArray)
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2:ndims(m)+1..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch]

# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
𝐱_weighted = spectral_conv(view(𝐱_fft, 1:m.modes, :, :), m.weight)
ranges = [1:dim_modes for dim_modes in m.modes]
𝐱_flattened = reshape(view(𝐱_fft, ranges..., :, :), prod(m.modes), size(𝐱_fft)[end-1:end]...)
𝐱_weighted = spectral_conv(𝐱_flattened, m.weight)
𝐱_shaped = reshape(𝐱_weighted, m.modes..., size(𝐱_weighted)[end-1:end]...)

# [x, out_chs, batch] <- [modes, out_chs, batch]
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, size(𝐱_fft, 1)-m.modes, Base.tail(size(𝐱_weighted))...), dims=1)
pad = zeros(ComplexF32, (collect(size(𝐱_fft)[1:ndims(m)])-collect(m.modes))..., size(𝐱_shaped)[end-1:end]...)
𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m))

𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (2, 1, 3)) # [out_chs, x, batch] <- [x, out_chs, batch]
𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]

return m.σ.(𝐱_outᵀ)
end

function FourierOperator(ch::Pair{<:Integer, <:Integer}, modes::Integer, σ=identity)
function FourierOperator(ch::Pair{S, S}, modes::NTuple{N, S}, σ=identity) where {S<:Integer, N}
return Chain(
Parallel(+, Dense(ch.first, ch.second), SpectralConv1d(ch, modes)),
Parallel(+, Dense(ch.first, ch.second), SpectralConv(ch, modes)),
x -> σ.(x)
)
end
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export
FourierNeuralOperator

function FourierNeuralOperator()
modes = 16
modes = (16, )
ch = 64 => 64
σ = relu

Expand Down
18 changes: 18 additions & 0 deletions test/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,21 @@
@test size(xs) == (2, 1024, 1000)
@test size(ys) == (1024, 1000)
end

@testset "unit gaussian normalizer" begin
dims = (3, 3, 5, 6)
𝐱 = rand(Float32, dims)

n = UnitGaussianNormalizer(𝐱)

@test size(n.mean) == size(n.std)
@test size(encode(n, 𝐱)) == dims
@test size(decode(n, encode(n, 𝐱))) == dims
end

@testset "get darcy flow data" begin
xs, ys, _, _ = get_darcy_flow_data()

@test size(xs) == (1, 85, 85, 1024)
@test size(ys) == (1, 85, 85, 1024)
end
57 changes: 46 additions & 11 deletions test/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,69 @@
@testset "SpectralConv1d" begin
modes = 16
modes = (16, )
ch = 64 => 64

m = Chain(
Dense(2, 64),
SpectralConv1d(ch, modes)
SpectralConv(ch, modes)
)
@test ndims(SpectralConv(ch, modes)) == 1

𝐱, _ = get_burgers_data(n=1000)
@test size(m(𝐱)) == (64, 1024, 1000)
𝐱, _ = get_burgers_data(n=5)
@test size(m(𝐱)) == (64, 1024, 5)

T = Float32
loss(x, y) = Flux.mse(m(x), y)
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
data = [(𝐱, rand(Float32, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FourierOperator" begin
modes = 16
@testset "FourierOperator1d" begin
modes = (16, )
ch = 64 => 64

m = Chain(
Dense(2, 64),
FourierOperator(ch, modes)
)

𝐱, _ = get_burgers_data(n=1000)
@test size(m(𝐱)) == (64, 1024, 1000)
𝐱, _ = get_burgers_data(n=5)
@test size(m(𝐱)) == (64, 1024, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
data = [(𝐱, rand(Float32, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "SpectralConv2d" begin
modes = (16, 16)
ch = 64 => 64

m = Chain(
Dense(1, 64),
SpectralConv(ch, modes)
)
@test ndims(SpectralConv(ch, modes)) == 2

𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
@test size(m(𝐱)) == (64, 22, 22, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FourierOperator2d" begin
modes = (16, 16)
ch = 64 => 64

m = Chain(
Dense(1, 64),
FourierOperator(ch, modes)
)

𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
@test size(m(𝐱)) == (64, 22, 22, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end