Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.
44 changes: 24 additions & 20 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,57 @@
export
SpectralConv1d,
SpectralConv,
FourierOperator

struct SpectralConv1d{T, S}
struct SpectralConv{T, S, N}
weight::T
in_channel::S
out_channel::S
modes::S
modes::NTuple{N, S}
ndim::S
σ
end

c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...)*im

function SpectralConv1d(
ch::Pair{<:Integer, <:Integer},
modes::Integer,
function SpectralConv(
ch::Pair{S, S},
modes::NTuple{N, S},
σ=identity;
init=c_glorot_uniform,
T::DataType=ComplexF32
)
) where {S<:Integer, N}
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(out_chs, in_chs, modes)
weights = scale * init(out_chs, in_chs, modes...)

return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
return SpectralConv(weights, in_chs, out_chs, modes, N, σ)
end

Flux.@functor SpectralConv1d
Flux.@functor SpectralConv

spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m] # TODO: extend `m` to n-dim

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2, 1, 3)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
function (m::SpectralConv)(𝐱::AbstractArray)
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (m.ndim+1, 1:m.ndim..., m.ndim+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1:m.ndim) # [x, in_chs, batch]

# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
𝐱_weighted = spectral_conv(view(𝐱_fft, 1:m.modes, :, :), m.weight)
ranges = [1:dim_modes for dim_modes in m.modes]
𝐱_weighted = spectral_conv(view(𝐱_fft, ranges..., :, :), m.weight)

# [x, out_chs, batch] <- [modes, out_chs, batch]
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, size(𝐱_fft, 1)-m.modes, Base.tail(size(𝐱_weighted))...), dims=1)
pad = zeros(ComplexF32, (collect(size(𝐱_fft)[1:m.ndim])-collect(m.modes))..., size(𝐱_weighted)[end-1:end]...)
𝐱_padded = cat(𝐱_weighted, pad, dims=1:m.ndim)

𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (2, 1, 3)) # [out_chs, x, batch] <- [x, out_chs, batch]
𝐱_out = ifft(𝐱_padded, 1:m.ndim) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (2:m.ndim+1..., 1, m.ndim+2)) # [out_chs, x, batch] <- [x, out_chs, batch]

return m.σ.(𝐱_outᵀ)
end

function FourierOperator(ch::Pair{<:Integer, <:Integer}, modes::Integer, σ=identity)
function FourierOperator(ch::Pair{S, S}, modes::NTuple{N, S}, σ=identity) where {S<:Integer, N}
return Chain(
Parallel(+, Dense(ch.first, ch.second), SpectralConv1d(ch, modes)),
Parallel(+, Dense(ch.first, ch.second), SpectralConv(ch, modes)),
x -> σ.(x)
)
end
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export
FourierNeuralOperator

function FourierNeuralOperator()
modes = 16
modes = (16, )
ch = 64 => 64
σ = relu

Expand Down
8 changes: 4 additions & 4 deletions test/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
@testset "SpectralConv1d" begin
modes = 16
@testset "SpectralConv" begin
modes = (16, )
ch = 64 => 64

m = Chain(
Dense(2, 64),
SpectralConv1d(ch, modes)
SpectralConv(ch, modes)
)

𝐱, _ = get_burgers_data(n=1000)
Expand All @@ -17,7 +17,7 @@
end

@testset "FourierOperator" begin
modes = 16
modes = (16, )
ch = 64 => 64

m = Chain(
Expand Down