Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d8370f8

Browse files
committed
refactor SpectralConv1d
1 parent defb9ca commit d8370f8

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

src/NeuralOperators.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ module NeuralOperators
66
using Flux
77
using FFTW
88
using Tullio
9+
using CUDA
10+
using CUDAKernels
11+
using KernelAbstractions
912
using Zygote
1013

1114
function __init__()

src/fourier.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using CUDA, CUDAKernels, KernelAbstractions
2-
31
export
42
SpectralConv1d,
53
FourierOperator,
@@ -17,9 +15,6 @@ function c_glorot_uniform(dims...)
1715
return Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
1816
end
1917

20-
t(𝐱) = @tullio 𝐱ᵀ[a, b, c] := 𝐱[b, a, c]
21-
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
22-
2318
function SpectralConv1d(
2419
ch::Pair{<:Integer, <:Integer},
2520
modes::Integer,
@@ -32,29 +27,32 @@ function SpectralConv1d(
3227
weights = scale * init(out_chs, in_chs, modes)
3328

3429
return Chain(
35-
t,
3630
x -> Zygote.hook(real, x),
3731
SpectralConv1d(weights, in_chs, out_chs, modes, σ),
38-
t
3932
)
4033
end
4134

4235
Flux.@functor SpectralConv1d
4336

37+
t(𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
38+
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
39+
4440
function (m::SpectralConv1d)(𝐱::AbstractArray)
45-
𝐱_fft = fft(𝐱, 1) # [x, in_chs, batch]
41+
𝐱ᵀ = t(𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
42+
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
4643
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]
4744

48-
# [modes, out_chs, batch] <- [modes, in_chs, batch] [out_chs, in_chs, modes]
45+
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
4946
𝐱_weighted = ein_mul(𝐱_selected, m.weight)
5047

5148
s = size(𝐱_weighted)[2:end]
52-
d = size(𝐱, 1) - m.modes
49+
d = size(𝐱ᵀ, 1) - m.modes
5350
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, d, s...), dims=1)
5451

5552
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
53+
𝐱_outᵀ = t(𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]
5654

57-
return m.σ.(real(𝐱_out))
55+
return m.σ.(real(𝐱_outᵀ))
5856
end
5957

6058
function FourierOperator(

0 commit comments

Comments
 (0)