1
1
export
2
2
SpectralConv,
3
+ SpectralConvPerm,
3
4
FourierOperator
4
5
5
- struct SpectralConv{N, T, S, F}
6
+ abstract type AbstractSpectralConv{N, T, S, F} end
7
+
8
+ struct SpectralConv{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
6
9
weight:: T
7
10
in_channel:: S
8
11
out_channel:: S
9
12
modes:: NTuple{N, S}
10
13
σ:: F
11
14
end
12
15
13
- struct SpectralConvPerm{N, T, S, F}
16
+ struct SpectralConvPerm{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
14
17
weight:: T
15
18
in_channel:: S
16
19
out_channel:: S
17
20
modes:: NTuple{N, S}
18
21
σ:: F
19
22
end
20
23
24
+ """
25
+ SpectralConv(
26
+ ch, modes, σ=identity;
27
+ init=c_glorot_uniform, permuted=false, T=ComplexF32
28
+ )
29
+
30
+ ## SpectralConv
31
+
32
+ * ``v(x)``: input
33
+ * ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform
34
+ * ``L``: linear transform on the lower Fouier modes.
35
+
36
+ ``v(x)`` -> ``F`` -> ``L`` -> ``F^{-1}``
37
+
38
+ ## Example
39
+
40
+ ```jldoctest
41
+ julia> SpectralConv(2=>5, (16, ))
42
+ SpectralConv(2 => 5, (16,), σ=identity)
43
+
44
+ julia> using Flux
45
+
46
+ julia> SpectralConv(2=>5, (16, ), relu)
47
+ SpectralConv(2 => 5, (16,), σ=relu)
48
+
49
+ julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
50
+ SpectralConvPerm(2 => 5, (16,), σ=relu)
51
+ ```
52
+ """
21
53
function SpectralConv (
22
54
ch:: Pair{S, S} ,
23
55
modes:: NTuple{N, S} ,
38
70
Flux. @functor SpectralConv
39
71
Flux. @functor SpectralConvPerm
40
72
41
- Base. ndims (:: SpectralConv{N} ) where {N} = N
42
- Base. ndims (:: SpectralConvPerm{N} ) where {N} = N
73
+ Base. ndims (:: AbstractSpectralConv{N} ) where {N} = N
74
+
75
+ function Base. show (io:: IO , l:: AbstractSpectralConv )
76
+ T = (l isa SpectralConv) ? SpectralConv : SpectralConvPerm
77
+ print (io, " $(string (T)) ($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ=$(string (l. σ)) )" )
78
+ end
43
79
44
- function spectral_conv (m, 𝐱)
80
+ function spectral_conv (m:: AbstractSpectralConv , 𝐱:: AbstractArray )
45
81
n_dims = ndims (𝐱)
46
82
47
83
𝐱_fft = fft (Zygote. hook (real, 𝐱), 1 : ndims (m)) # [x, in_chs, batch]
@@ -54,33 +90,86 @@ function spectral_conv(m, 𝐱)
54
90
return m. σ .(𝐱_ifft)
55
91
end
56
92
57
- function (m:: SpectralConv )(𝐱:: AbstractArray )
93
+ function (m:: SpectralConv )(𝐱)
58
94
𝐱ᵀ = permutedims (𝐱, (ntuple (i-> i+ 1 , ndims (m))... , 1 , ndims (m)+ 2 )) # [x, in_chs, batch] <- [in_chs, x, batch]
59
95
𝐱_out = spectral_conv (m, 𝐱ᵀ) # [x, out_chs, batch]
60
96
𝐱_outᵀ = permutedims (𝐱_out, (ndims (m)+ 1 , 1 : ndims (m)... , ndims (m)+ 2 )) # [out_chs, x, batch] <- [x, out_chs, batch]
61
97
62
98
return 𝐱_outᵀ
63
99
end
64
100
65
- function (m:: SpectralConvPerm )(𝐱:: AbstractArray )
101
+ function (m:: SpectralConvPerm )(𝐱)
66
102
return spectral_conv (m, 𝐱) # [x, out_chs, batch]
67
103
end
68
104
69
105
# ###########
70
106
# operator #
71
107
# ###########
72
108
109
+ """
110
+ FourierOperator(ch, modes, σ=identity; permuted=false)
111
+
112
+ ## FourierOperator
113
+
114
+ * ``v(x)``: input
115
+ * ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform
116
+ * ``L``: linear transform on the lower Fouier modes
117
+ * ``D``: local linear transform
118
+
119
+ ```
120
+ ┌ F -> L -> F¯¹ ┐
121
+ v(x) -> ┤ ├ -> + -> σ
122
+ └ D ┘
123
+ ```
124
+
125
+ ## Example
126
+
127
+ ```jldoctest
128
+ julia> FourierOperator(2=>5, (16, ))
129
+ Chain(
130
+ Parallel(
131
+ +,
132
+ Dense(2, 5), # 15 parameters
133
+ SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
134
+ ),
135
+ NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
136
+ ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
137
+
138
+ julia> using Flux
139
+
140
+ julia> FourierOperator(2=>5, (16, ), relu)
141
+ Chain(
142
+ Parallel(
143
+ +,
144
+ Dense(2, 5), # 15 parameters
145
+ SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
146
+ ),
147
+ NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
148
+ ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
149
+
150
+ julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
151
+ Chain(
152
+ Parallel(
153
+ +,
154
+ Conv((1,), 2 => 5), # 15 parameters
155
+ SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
156
+ ),
157
+ NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
158
+ ) # Total: 3 arrays, 175 parameters, 1.871 KiB.
159
+ ```
160
+ """
73
161
function FourierOperator (
74
162
ch:: Pair{S, S} ,
75
163
modes:: NTuple{N, S} ,
76
164
σ= identity;
77
165
permuted= false
78
166
) where {S<: Integer , N}
79
167
short_cut = permuted ? Conv (Tuple (ones (Int, length (modes))), ch) : Dense (ch. first, ch. second)
168
+ activation_func (x) = σ .(x)
80
169
81
170
return Chain (
82
171
Parallel (+ , short_cut, SpectralConv (ch, modes, permuted= permuted)),
83
- x -> σ .(x)
172
+ activation_func
84
173
)
85
174
end
86
175
0 commit comments