1
1
export
2
2
SpectralConv,
3
- SpectralConvPerm,
4
3
FourierOperator
5
4
6
- abstract type AbstractSpectralConv{N, T, S, F} end
7
-
8
- struct SpectralConv{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
9
- weight:: T
10
- in_channel:: S
11
- out_channel:: S
12
- modes:: NTuple{N, S}
13
- σ:: F
14
- end
15
-
16
- struct SpectralConvPerm{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
5
+ struct SpectralConv{P, N, T, S, F}
17
6
weight:: T
18
7
in_channel:: S
19
8
out_channel:: S
33
22
* `modes`: The Fourier modes to be preserved.
34
23
* `σ`: Activation function.
35
24
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
36
- data in the order of `(..., ch, batch)`, otherwise the order is `(ch, ..., batch)`.
25
+ data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch , batch)`.
37
26
38
27
## Example
39
28
40
29
```jldoctest
41
30
julia> SpectralConv(2=>5, (16, ))
42
- SpectralConv(2 => 5, (16,), σ=identity)
31
+ SpectralConv(2 => 5, (16,), σ=identity, permuted=false )
43
32
44
33
julia> using Flux
45
34
46
35
julia> SpectralConv(2=>5, (16, ), relu)
47
- SpectralConv(2 => 5, (16,), σ=relu)
36
+ SpectralConv(2 => 5, (16,), σ=relu, permuted=false )
48
37
49
38
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
50
- SpectralConvPerm (2 => 5, (16,), σ=relu)
39
+ SpectralConv (2 => 5, (16,), σ=relu, permuted=true )
51
40
```
52
41
"""
53
42
function SpectralConv (
@@ -61,23 +50,23 @@ function SpectralConv(
61
50
in_chs, out_chs = ch
62
51
scale = one (T) / (in_chs * out_chs)
63
52
weights = scale * init (out_chs, in_chs, prod (modes))
53
+ W = typeof (weights)
54
+ F = typeof (σ)
64
55
65
- L = permuted ? SpectralConvPerm : SpectralConv
66
-
67
- return L (weights, in_chs, out_chs, modes, σ)
56
+ return SpectralConv {permuted,N,W,S,F} (weights, in_chs, out_chs, modes, σ)
68
57
end
69
58
70
59
Flux. @functor SpectralConv
71
- Flux. @functor SpectralConvPerm
72
60
73
- Base. ndims (:: AbstractSpectralConv{N} ) where {N} = N
61
+ Base. ndims (:: SpectralConv{P,N} ) where {P,N} = N
62
+
63
+ permuted (:: SpectralConv{P} ) where {P} = P
74
64
75
- function Base. show (io:: IO , l:: AbstractSpectralConv )
76
- T = (l isa SpectralConv) ? SpectralConv : SpectralConvPerm
77
- print (io, " $(string (T)) ($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ=$(string (l. σ)) )" )
65
+ function Base. show (io:: IO , l:: SpectralConv{P} ) where {P}
66
+ print (io, " SpectralConv($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ=$(string (l. σ)) , permuted=$P )" )
78
67
end
79
68
80
- function spectral_conv (m:: AbstractSpectralConv , 𝐱:: AbstractArray )
69
+ function spectral_conv (m:: SpectralConv , 𝐱:: AbstractArray )
81
70
n_dims = ndims (𝐱)
82
71
83
72
𝐱_fft = fft (Zygote. hook (real, 𝐱), 1 : ndims (m)) # [x, in_chs, batch]
@@ -90,22 +79,28 @@ function spectral_conv(m::AbstractSpectralConv, 𝐱::AbstractArray)
90
79
return m. σ .(𝐱_ifft)
91
80
end
92
81
93
- function (m:: SpectralConv )(𝐱)
82
+ function (m:: SpectralConv{false} )(𝐱)
94
83
𝐱ᵀ = permutedims (𝐱, (ntuple (i-> i+ 1 , ndims (m))... , 1 , ndims (m)+ 2 )) # [x, in_chs, batch] <- [in_chs, x, batch]
95
84
𝐱_out = spectral_conv (m, 𝐱ᵀ) # [x, out_chs, batch]
96
85
𝐱_outᵀ = permutedims (𝐱_out, (ndims (m)+ 1 , 1 : ndims (m)... , ndims (m)+ 2 )) # [out_chs, x, batch] <- [x, out_chs, batch]
97
86
98
87
return 𝐱_outᵀ
99
88
end
100
89
101
- function (m:: SpectralConvPerm )(𝐱)
90
+ function (m:: SpectralConv{true} )(𝐱)
102
91
return spectral_conv (m, 𝐱) # [x, out_chs, batch]
103
92
end
104
93
105
94
# ###########
106
95
# operator #
107
96
# ###########
108
97
98
+ struct FourierOperator{L, C, F}
99
+ linear:: L
100
+ conv:: C
101
+ σ:: F
102
+ end
103
+
109
104
"""
110
105
FourierOperator(ch, modes, σ=identity; permuted=false)
111
106
@@ -115,42 +110,21 @@ end
115
110
* `modes`: The Fourier modes to be preserved for spectral convolution.
116
111
* `σ`: Activation function.
117
112
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
118
- data in the order of `(..., ch, batch)`, otherwise the order is `(ch, ..., batch)`.
113
+ data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch , batch)`.
119
114
120
115
## Example
121
116
122
117
```jldoctest
123
118
julia> FourierOperator(2=>5, (16, ))
124
- Chain(
125
- Parallel(
126
- +,
127
- Dense(2, 5), # 15 parameters
128
- SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
129
- ),
130
- NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
131
- ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
119
+ FourierOperator(2 => 5, (16,), σ=identity, permuted=false)
132
120
133
121
julia> using Flux
134
122
135
123
julia> FourierOperator(2=>5, (16, ), relu)
136
- Chain(
137
- Parallel(
138
- +,
139
- Dense(2, 5), # 15 parameters
140
- SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
141
- ),
142
- NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
143
- ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
124
+ FourierOperator(2 => 5, (16,), σ=relu, permuted=false)
144
125
145
126
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
146
- Chain(
147
- Parallel(
148
- +,
149
- Conv((1,), 2 => 5), # 15 parameters
150
- SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
151
- ),
152
- NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
153
- ) # Total: 3 arrays, 175 parameters, 1.871 KiB.
127
+ FourierOperator(2 => 5, (16,), σ=relu, permuted=true)
154
128
```
155
129
"""
156
130
function FourierOperator (
@@ -159,15 +133,23 @@ function FourierOperator(
159
133
σ= identity;
160
134
permuted= false
161
135
) where {S<: Integer , N}
162
- short_cut = permuted ? Conv (Tuple (ones (Int, length (modes))), ch) : Dense (ch. first, ch. second)
163
- activation_func (x) = σ .(x )
136
+ linear = permuted ? Conv (Tuple (ones (Int, length (modes))), ch) : Dense (ch. first, ch. second)
137
+ conv = SpectralConv (ch, modes; permuted = permuted )
164
138
165
- return Chain (
166
- Parallel (+ , short_cut, SpectralConv (ch, modes, permuted= permuted)),
167
- activation_func
168
- )
139
+ return FourierOperator (linear, conv, σ)
140
+ end
141
+
142
+ Flux. @functor FourierOperator
143
+
144
+ function Base. show (io:: IO , l:: FourierOperator )
145
+ print (io, " FourierOperator($(l. conv. in_channel) => $(l. conv. out_channel) , $(l. conv. modes) , σ=$(string (l. σ)) , permuted=$(permuted (l. conv)) )" )
169
146
end
170
147
148
+ function (m:: FourierOperator )(𝐱)
149
+ return m. σ .(m. linear (𝐱) + m. conv (𝐱))
150
+ end
151
+
152
+
171
153
# ########
172
154
# utils #
173
155
# ########
0 commit comments