2
2
SpectralConv,
3
3
FourierOperator
4
4
5
- struct SpectralConv{P, N, T, S, F }
5
+ struct SpectralConv{P, N, T, S}
6
6
weight:: T
7
7
in_channel:: S
8
8
out_channel:: S
9
9
modes:: NTuple{N, S}
10
- σ:: F
11
10
end
12
11
13
12
function SpectralConv {P} (
14
13
weight:: T ,
15
14
in_channel:: S ,
16
15
out_channel:: S ,
17
16
modes:: NTuple{N, S} ,
18
- σ:: F
19
- ) where {P, N, T, S, F}
20
- return SpectralConv {P, N, T, S, F} (weight, in_channel, out_channel, modes, σ)
17
+ ) where {P, N, T, S}
18
+ return SpectralConv {P, N, T, S} (weight, in_channel, out_channel, modes)
21
19
end
22
20
23
21
"""
24
22
SpectralConv(
25
- ch, modes, σ=identity ;
23
+ ch, modes;
26
24
init=c_glorot_uniform, permuted=false, T=ComplexF32
27
25
)
28
26
29
27
## Arguments
30
28
31
29
* `ch`: Input and output channel size, e.g. `64=>64`.
32
30
* `modes`: The Fourier modes to be preserved.
33
- * `σ`: Activation function.
34
31
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
35
32
data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
36
33
37
34
## Example
38
35
39
36
```jldoctest
40
37
julia> SpectralConv(2=>5, (16, ))
41
- SpectralConv(2 => 5, (16,), σ=identity, permuted=false)
42
-
43
- julia> using Flux
44
-
45
- julia> SpectralConv(2=>5, (16, ), relu)
46
- SpectralConv(2 => 5, (16,), σ=relu, permuted=false)
38
+ SpectralConv(2 => 5, (16,), permuted=false)
47
39
48
- julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
49
- SpectralConv(2 => 5, (16,), σ=relu, permuted=true)
40
+ julia> SpectralConv(2=>5, (16, ), permuted=true)
41
+ SpectralConv(2 => 5, (16,), permuted=true)
50
42
```
51
43
"""
52
44
function SpectralConv (
53
45
ch:: Pair{S, S} ,
54
- modes:: NTuple{N, S} ,
55
- σ= identity;
46
+ modes:: NTuple{N, S} ;
56
47
init= c_glorot_uniform,
57
48
permuted= false ,
58
49
T:: DataType = ComplexF32
@@ -61,7 +52,7 @@ function SpectralConv(
61
52
scale = one (T) / (in_chs * out_chs)
62
53
weights = scale * init (prod (modes), in_chs, out_chs)
63
54
64
- return SpectralConv {permuted} (weights, in_chs, out_chs, modes, σ )
55
+ return SpectralConv {permuted} (weights, in_chs, out_chs, modes)
65
56
end
66
57
67
58
Flux. @functor SpectralConv{true }
@@ -72,7 +63,7 @@ Base.ndims(::SpectralConv{P, N}) where {P, N} = N
72
63
ispermuted (:: SpectralConv{P} ) where {P} = P
73
64
74
65
function Base. show (io:: IO , l:: SpectralConv{P} ) where {P}
75
- print (io, " SpectralConv($(l. in_channel) => $(l. out_channel) , $(l. modes) , σ= $( string (l . σ)) , permuted=$P )" )
66
+ print (io, " SpectralConv($(l. in_channel) => $(l. out_channel) , $(l. modes) , permuted=$P )" )
76
67
end
77
68
78
69
function spectral_conv (m:: SpectralConv , 𝐱:: AbstractArray )
@@ -85,7 +76,7 @@ function spectral_conv(m::SpectralConv, 𝐱::AbstractArray)
85
76
𝐱_padded = spectral_pad (𝐱_shaped, (size (𝐱_fft)[1 : end - 2 ]. .. , size (𝐱_weighted, 2 ), size (𝐱_weighted, 3 ))) # [x, out_chs, batch] <- [modes, out_chs, batch]
86
77
𝐱_ifft = real (ifft (𝐱_padded, 1 : ndims (m))) # [x, out_chs, batch]
87
78
88
- return m . σ .( 𝐱_ifft)
79
+ return 𝐱_ifft
89
80
end
90
81
91
82
function (m:: SpectralConv{false} )(𝐱)
0 commit comments