1
1
export
2
2
SpectralConv1d,
3
- FourierOperator,
4
- FNO
3
+ FourierOperator
5
4
6
5
struct SpectralConv1d{T, S}
7
6
weight:: T
@@ -28,26 +27,31 @@ function SpectralConv1d(
28
27
29
28
return Chain (
30
29
x -> Zygote. hook (real, x),
31
- SpectralConv1d (weights, in_chs, out_chs, modes, σ)
30
+ SpectralConv1d (weights, in_chs, out_chs, modes, σ),
32
31
)
33
32
end
34
33
35
34
Flux. @functor SpectralConv1d
36
35
36
+ t (𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
37
+ ein_mul (𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
38
+
37
39
function (m:: SpectralConv1d )(𝐱:: AbstractArray )
38
- 𝐱_fft = fft (𝐱, 2 ) # [in_chs, x, batch]
39
- 𝐱_selected = 𝐱_fft[:, 1 : m. modes, :] # [in_chs, modes, batch]
40
+ 𝐱ᵀ = t (𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
41
+ 𝐱_fft = fft (𝐱ᵀ, 1 ) # [x, in_chs, batch]
42
+ 𝐱_selected = 𝐱_fft[1 : m. modes, :, :] # [modes, in_chs, batch]
40
43
41
- # [out_chs, modes , batch] <- [in_chs, modes , batch] [out_chs, in_chs, modes]
42
- @tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i , m, b] * m . weight[o, i, m]
44
+ # [modes, out_chs , batch] <- [modes, in_chs , batch] * [out_chs, in_chs, modes]
45
+ 𝐱_weighted = ein_mul ( 𝐱_selected, m. weight)
43
46
44
- s = size (𝐱_weighted)
45
- d = size (𝐱, 2 ) - m. modes
46
- 𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, s[ 1 ], d, s[ 3 : end ] . .. ), dims= 2 )
47
+ s = size (𝐱_weighted)[ 2 : end ]
48
+ d = size (𝐱ᵀ, 1 ) - m. modes
49
+ 𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, d, s... ), dims= 1 )
47
50
48
- 𝐱_out = ifft (𝐱_padded, 2 )
51
+ 𝐱_out = ifft (𝐱_padded, 1 ) # [x, out_chs, batch]
52
+ 𝐱_outᵀ = t (𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]
49
53
50
- return m. σ .(real (𝐱_out ))
54
+ return m. σ .(real (𝐱_outᵀ ))
51
55
end
52
56
53
57
function FourierOperator (
0 commit comments