7
7
FourierOperator,
8
8
FNO
9
9
10
- struct SpectralConv1d{T,S}
10
+ c_glorot_uniform (dims... ) = Flux. glorot_uniform (dims... ) + Flux. glorot_uniform (dims... ) * im
11
+
12
+ struct SpectralConv1d{T, S}
11
13
weight:: T
12
14
in_channel:: S
13
15
out_channel:: S
@@ -19,8 +21,8 @@ function SpectralConv1d(
19
21
ch:: Pair{<:Integer,<:Integer} ,
20
22
modes:: Integer ,
21
23
σ= identity;
22
- init= Flux . glorot_uniform ,
23
- T:: DataType = Float32
24
+ init= c_glorot_uniform ,
25
+ T:: DataType = ComplexF32
24
26
)
25
27
in_chs, out_chs = ch
26
28
scale = one (T) / (in_chs * out_chs)
32
34
Flux. @functor SpectralConv1d
33
35
34
36
function (m:: SpectralConv1d )(𝐱:: AbstractArray )
35
- 𝐱_fft = rfft (𝐱, 1 ) # [x, in_chs , batch]
36
- 𝐱_selected = 𝐱_fft[1 : m. modes, :, : ] # [modes, in_chs , batch]
37
+ 𝐱_fft = fft (𝐱, 2 ) # [in_chs, x , batch]
38
+ 𝐱_selected = 𝐱_fft[:, 1 : m. modes, :] # [in_chs, modes , batch]
37
39
38
- # [modes, out_chs , batch] <- [modes, in_chs , batch] [out_chs, in_chs, modes]
39
- @tullio 𝐱_weighted[m, o , b] := 𝐱_selected[m, i , b] * m. weight[o, i, m]
40
+ # [out_chs, modes , batch] <- [in_chs, modes , batch] [out_chs, in_chs, modes]
41
+ @tullio 𝐱_weighted[o, m , b] := 𝐱_selected[i, m , b] * m. weight[o, i, m]
40
42
41
- d = size (𝐱, 1 ) ÷ 2 + 1 - m. modes
42
- 𝐱_padded = cat (𝐱_weighted, zeros (Float32, d, size (𝐱)[2 : end ]. .. ), dims= 1 )
43
+ s = size (𝐱_weighted)
44
+ d = size (𝐱, 2 ) - m. modes
45
+ 𝐱_padded = cat (𝐱_weighted, zeros (ComplexF32, s[1 ], d, s[3 : end ]. .. ), dims= 2 )
43
46
44
- 𝐱_out = irfft (𝐱_padded , size (𝐱, 1 ), 1 )
47
+ 𝐱_out = ifft (𝐱_padded, 2 )
45
48
46
49
return m. σ .(𝐱_out)
47
50
end
@@ -53,7 +56,7 @@ function FourierOperator(
53
56
)
54
57
return Chain (
55
58
Parallel (+ ,
56
- Conv (( 1 , ), ch ),
59
+ Dense (ch . first, ch . second, init = c_glorot_uniform ),
57
60
SpectralConv1d (ch, modes)
58
61
),
59
62
x -> σ .(x)
63
66
function FNO ()
64
67
modes = 16
65
68
ch = 64 => 64
69
+ σ = x -> @. log (1 + exp (x))
66
70
67
71
return Chain (
68
- Conv (( 1 , ), 2 => 64 ),
69
- FourierOperator (ch, modes, relu ),
70
- FourierOperator (ch, modes, relu ),
71
- FourierOperator (ch, modes, relu ),
72
+ Dense ( 2 , 64 , init = c_glorot_uniform ),
73
+ FourierOperator (ch, modes, σ ),
74
+ FourierOperator (ch, modes, σ ),
75
+ FourierOperator (ch, modes, σ ),
72
76
FourierOperator (ch, modes),
73
- Conv (( 1 , ), 64 => 128 , relu ),
74
- Conv (( 1 , ), 128 => 1 ),
77
+ Dense ( 64 , 128 , σ, init = c_glorot_uniform ),
78
+ Dense ( 128 , 1 , init = c_glorot_uniform ),
75
79
flatten
76
80
)
77
81
end
0 commit comments