28
28
29
29
```jldoctest
30
30
julia> SpectralConv(2=>5, (16, ))
31
- SpectralConv(2 => 5, (16,), σ=identity)
31
+ SpectralConv(2 => 5, (16,), σ=identity, permuted=false )
32
32
33
33
julia> using Flux
34
34
35
35
julia> SpectralConv(2=>5, (16, ), relu)
36
- SpectralConv(2 => 5, (16,), σ=relu)
36
+ SpectralConv(2 => 5, (16,), σ=relu, permuted=false )
37
37
38
38
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
39
- SpectralConv(2 => 5, (16,), σ=relu)
39
+ SpectralConv(2 => 5, (16,), σ=relu, permuted=true )
40
40
```
41
41
"""
42
42
function SpectralConv (
@@ -60,8 +60,10 @@ Flux.@functor SpectralConv
60
60
61
61
Base. ndims (:: SpectralConv{P,N} ) where {P,N} = N
62
62
63
- function Base. show (io:: IO , l:: SpectralConv )
64
- print (io, " SpectralConv($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ=$(string (l. σ)) )" )
63
+ permuted (:: SpectralConv{P} ) where {P} = P
64
+
65
+ function Base. show (io:: IO , l:: SpectralConv{P} ) where {P}
66
+ print (io, " SpectralConv($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ=$(string (l. σ)) , permuted=$P )" )
65
67
end
66
68
67
69
function spectral_conv (m:: SpectralConv , 𝐱:: AbstractArray )
@@ -114,36 +116,15 @@ end
114
116
115
117
```jldoctest
116
118
julia> FourierOperator(2=>5, (16, ))
117
- Chain(
118
- Parallel(
119
- +,
120
- Dense(2, 5), # 15 parameters
121
- SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
122
- ),
123
- NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
124
- ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
119
+ FourierOperator(2 => 5, (16,), σ=identity, permuted=false)
125
120
126
121
julia> using Flux
127
122
128
123
julia> FourierOperator(2=>5, (16, ), relu)
129
- Chain(
130
- Parallel(
131
- +,
132
- Dense(2, 5), # 15 parameters
133
- SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
134
- ),
135
- NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
136
- ) # Total: 3 arrays, 175 parameters, 1.668 KiB.
124
+ FourierOperator(2 => 5, (16,), σ=relu, permuted=false)
137
125
138
126
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
139
- Chain(
140
- Parallel(
141
- +,
142
- Conv((1,), 2 => 5), # 15 parameters
143
- SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
144
- ),
145
- NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
146
- ) # Total: 3 arrays, 175 parameters, 1.871 KiB.
127
+ FourierOperator(2 => 5, (16,), σ=relu, permuted=true)
147
128
```
148
129
"""
149
130
function FourierOperator (
160
141
161
142
Flux. @functor FourierOperator
162
143
144
+ function Base. show (io:: IO , l:: FourierOperator )
145
+ print (io, " FourierOperator($(l. conv. in_channel) => $(l. conv. out_channel) , $(l. conv. modes) , σ=$(string (l. σ)) , permuted=$(permuted (l. conv)) )" )
146
+ end
147
+
163
148
function (m:: FourierOperator )(𝐱)
164
149
return m. σ .(m. linear (𝐱) + m. conv (𝐱))
165
150
end
0 commit comments