1
- using CUDA, CUDAKernels, KernelAbstractions
2
-
3
1
export
4
2
SpectralConv1d,
5
3
FourierOperator,
@@ -17,9 +15,6 @@ function c_glorot_uniform(dims...)
17
15
return Flux. glorot_uniform (dims... ) + Flux. glorot_uniform (dims... ) * im
18
16
end
19
17
20
- t (𝐱) = @tullio 𝐱ᵀ[a, b, c] := 𝐱[b, a, c]
21
- ein_mul (𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
22
-
23
18
function SpectralConv1d (
24
19
ch:: Pair{<:Integer, <:Integer} ,
25
20
modes:: Integer ,
@@ -32,29 +27,32 @@ function SpectralConv1d(
32
27
weights = scale * init (out_chs, in_chs, modes)
33
28
34
29
return Chain (
35
- t,
36
30
x -> Zygote. hook (real, x),
37
31
SpectralConv1d (weights, in_chs, out_chs, modes, σ),
38
- t
39
32
)
40
33
end
41
34
42
35
Flux. @functor SpectralConv1d
43
36
37
+ t (𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
38
+ ein_mul (𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
39
+
44
40
function (m:: SpectralConv1d )(𝐱:: AbstractArray )
45
- 𝐱_fft = fft (𝐱, 1 ) # [x, in_chs, batch]
41
+ 𝐱ᵀ = t (𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
42
+ 𝐱_fft = fft (𝐱ᵀ, 1 ) # [x, in_chs, batch]
46
43
𝐱_selected = 𝐱_fft[1 : m. modes, :, :] # [modes, in_chs, batch]
47
44
48
- # [modes, out_chs, batch] <- [modes, in_chs, batch] [out_chs, in_chs, modes]
45
+ # [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
49
46
𝐱_weighted = ein_mul (𝐱_selected, m. weight)
50
47
51
48
s = size (𝐱_weighted)[2 : end ]
52
- d = size (𝐱 , 1 ) - m. modes
49
+ d = size (𝐱ᵀ , 1 ) - m. modes
53
50
𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, d, s... ), dims= 1 )
54
51
55
52
𝐱_out = ifft (𝐱_padded, 1 ) # [x, out_chs, batch]
53
+ 𝐱_outᵀ = t (𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]
56
54
57
- return m. σ .(real (𝐱_out ))
55
+ return m. σ .(real (𝐱_outᵀ ))
58
56
end
59
57
60
58
function FourierOperator (
0 commit comments