1
+ using CUDA, CUDAKernels, KernelAbstractions
2
+
1
3
export
2
4
SpectralConv1d,
3
5
FourierOperator,
@@ -15,6 +17,9 @@ function c_glorot_uniform(dims...)
15
17
return Flux. glorot_uniform (dims... ) + Flux. glorot_uniform (dims... ) * im
16
18
end
17
19
20
+ t (𝐱) = @tullio 𝐱ᵀ[a, b, c] := 𝐱[b, a, c]
21
+ ein_mul (𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
22
+
18
23
function SpectralConv1d (
19
24
ch:: Pair{<:Integer, <:Integer} ,
20
25
modes:: Integer ,
@@ -27,25 +32,27 @@ function SpectralConv1d(
27
32
weights = scale * init (out_chs, in_chs, modes)
28
33
29
34
return Chain (
35
+ t,
30
36
x -> Zygote. hook (real, x),
31
- SpectralConv1d (weights, in_chs, out_chs, modes, σ)
37
+ SpectralConv1d (weights, in_chs, out_chs, modes, σ),
38
+ t
32
39
)
33
40
end
34
41
35
42
Flux. @functor SpectralConv1d
36
43
37
44
function (m:: SpectralConv1d )(𝐱:: AbstractArray )
38
- 𝐱_fft = fft (𝐱, 2 ) # [in_chs, x , batch]
39
- 𝐱_selected = 𝐱_fft[:, 1 : m. modes, :] # [in_chs, modes , batch]
45
+ 𝐱_fft = fft (𝐱, 1 ) # [x, in_chs , batch]
46
+ 𝐱_selected = 𝐱_fft[1 : m. modes, :, : ] # [modes, in_chs , batch]
40
47
41
- # [out_chs, modes , batch] <- [in_chs, modes , batch] [out_chs, in_chs, modes]
42
- @tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i , m, b] * m . weight[o, i, m]
48
+ # [modes, out_chs , batch] <- [modes, in_chs , batch] [out_chs, in_chs, modes]
49
+ 𝐱_weighted = ein_mul ( 𝐱_selected, m. weight)
43
50
44
- s = size (𝐱_weighted)
45
- d = size (𝐱, 2 ) - m. modes
46
- 𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, s[ 1 ], d, s[ 3 : end ] . .. ), dims= 2 )
51
+ s = size (𝐱_weighted)[ 2 : end ]
52
+ d = size (𝐱, 1 ) - m. modes
53
+ 𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, d, s... ), dims= 1 )
47
54
48
- 𝐱_out = ifft (𝐱_padded, 2 )
55
+ 𝐱_out = ifft (𝐱_padded, 1 ) # [x, out_chs, batch]
49
56
50
57
return m. σ .(real (𝐱_out))
51
58
end
0 commit comments