|
1 | 1 | export
|
2 |
| - SpectralConv1d, |
| 2 | + SpectralConv, |
3 | 3 | FourierOperator
|
4 | 4 |
|
5 |
| -struct SpectralConv1d{T, S} |
| 5 | +struct SpectralConv{N, T, S} |
6 | 6 | weight::T
|
7 | 7 | in_channel::S
|
8 | 8 | out_channel::S
|
9 |
| - modes::S |
| 9 | + modes::NTuple{N, S} |
| 10 | + ndim::S |
10 | 11 | σ
|
11 | 12 | end
|
12 | 13 |
|
13 | 14 | c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...)*im
|
14 | 15 |
|
15 |
| -function SpectralConv1d( |
16 |
| - ch::Pair{<:Integer, <:Integer}, |
17 |
| - modes::Integer, |
| 16 | +function SpectralConv( |
| 17 | + ch::Pair{S, S}, |
| 18 | + modes::NTuple{N, S}, |
18 | 19 | σ=identity;
|
19 | 20 | init=c_glorot_uniform,
|
20 | 21 | T::DataType=ComplexF32
|
21 |
| -) |
| 22 | +) where {S<:Integer, N} |
22 | 23 | in_chs, out_chs = ch
|
23 | 24 | scale = one(T) / (in_chs * out_chs)
|
24 |
| - weights = scale * init(out_chs, in_chs, modes) |
| 25 | + weights = scale * init(out_chs, in_chs, prod(modes)) |
25 | 26 |
|
26 |
| - return SpectralConv1d(weights, in_chs, out_chs, modes, σ) |
| 27 | + return SpectralConv(weights, in_chs, out_chs, modes, N, σ) |
27 | 28 | end
|
28 | 29 |
|
29 |
| -Flux.@functor SpectralConv1d |
| 30 | +Flux.@functor SpectralConv |
30 | 31 |
|
| 32 | +Base.ndims(::SpectralConv{N}) where {N} = N |
| 33 | + |
| 34 | +# [prod(m.modes), out_chs, batch] <- [prod(m.modes), in_chs, batch] * [out_chs, in_chs, prod(m.modes)] |
31 | 35 | spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
|
32 | 36 |
|
33 |
| -function (m::SpectralConv1d)(𝐱::AbstractArray) |
34 |
| - 𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2, 1, 3)) # [x, in_chs, batch] <- [in_chs, x, batch] |
35 |
| - 𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch] |
| 37 | +function (m::SpectralConv)(𝐱::AbstractArray) |
| 38 | + n_dims = ndims(𝐱) |
| 39 | + |
| 40 | + 𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (ntuple(i->i+1, ndims(m))..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch] |
| 41 | + 𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch] |
| 42 | + |
| 43 | + 𝐱_flattened = reshape(view(𝐱_fft, map(d->1:d, m.modes)..., :, :), :, size(𝐱_fft, n_dims-1), size(𝐱_fft, n_dims)) |
| 44 | + 𝐱_weighted = spectral_conv(𝐱_flattened, m.weight) # [prod(m.modes), out_chs, batch], only 3-dims |
| 45 | + 𝐱_shaped = reshape(𝐱_weighted, m.modes..., size(𝐱_weighted, 2), size(𝐱_weighted, 3)) |
36 | 46 |
|
37 |
| - # [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes] |
38 |
| - 𝐱_weighted = spectral_conv(view(𝐱_fft, 1:m.modes, :, :), m.weight) |
39 | 47 | # [x, out_chs, batch] <- [modes, out_chs, batch]
|
40 |
| - 𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, size(𝐱_fft, 1)-m.modes, Base.tail(size(𝐱_weighted))...), dims=1) |
| 48 | + pad = zeros(ComplexF32, ntuple(i->size(𝐱_fft, i)-m.modes[i], ndims(m))..., size(𝐱_shaped, n_dims-1), size(𝐱_shaped, n_dims)) |
| 49 | + 𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m)) |
41 | 50 |
|
42 |
| - 𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch] |
43 |
| - 𝐱_outᵀ = permutedims(real(𝐱_out), (2, 1, 3)) # [out_chs, x, batch] <- [x, out_chs, batch] |
| 51 | + 𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch] |
| 52 | + 𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch] |
44 | 53 |
|
45 | 54 | return m.σ.(𝐱_outᵀ)
|
46 | 55 | end
|
47 | 56 |
|
48 |
| -function FourierOperator(ch::Pair{<:Integer, <:Integer}, modes::Integer, σ=identity) |
| 57 | +function FourierOperator(ch::Pair{S, S}, modes::NTuple{N, S}, σ=identity) where {S<:Integer, N} |
49 | 58 | return Chain(
|
50 |
| - Parallel(+, Dense(ch.first, ch.second), SpectralConv1d(ch, modes)), |
| 59 | + Parallel(+, Dense(ch.first, ch.second), SpectralConv(ch, modes)), |
51 | 60 | x -> σ.(x)
|
52 | 61 | )
|
53 | 62 | end
|
0 commit comments