|
3 | 3 | FourierOperator
|
4 | 4 |
|
5 | 5 | struct SpectralConv{P, N, T, S, F}
|
6 |
| - permuted::Bool |
7 | 6 | weight::T
|
8 | 7 | in_channel::S
|
9 | 8 | out_channel::S
|
10 | 9 | modes::NTuple{N, S}
|
11 | 10 | σ::F
|
12 | 11 | end
|
13 | 12 |
|
14 |
| -function SpectralConv( |
15 |
| - permuted::Bool, |
| 13 | +function SpectralConv{P}( |
16 | 14 | weight::T,
|
17 | 15 | in_channel::S,
|
18 | 16 | out_channel::S,
|
19 | 17 | modes::NTuple{N, S},
|
20 | 18 | σ::F
|
21 |
| -) where {N, T, S, F} |
22 |
| - return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ) |
| 19 | +) where {P, N, T, S, F} |
| 20 | + return SpectralConv{P, N, T, S, F}(weight, in_channel, out_channel, modes, σ) |
23 | 21 | end
|
24 | 22 |
|
25 | 23 | """
|
@@ -63,13 +61,16 @@ function SpectralConv(
|
63 | 61 | scale = one(T) / (in_chs * out_chs)
|
64 | 62 | weights = scale * init(out_chs, in_chs, prod(modes))
|
65 | 63 |
|
66 |
| - return SpectralConv(permuted, weights, in_chs, out_chs, modes, σ) |
| 64 | + return SpectralConv{permuted}(weights, in_chs, out_chs, modes, σ) |
67 | 65 | end
|
68 | 66 |
|
69 |
| -Flux.@functor SpectralConv |
| 67 | +Flux.@functor SpectralConv{true} |
| 68 | +Flux.@functor SpectralConv{false} |
70 | 69 |
|
71 | 70 | Base.ndims(::SpectralConv{P, N}) where {P, N} = N
|
72 | 71 |
|
| 72 | +ispermuted(::SpectralConv{P}) where {P} = P |
| 73 | + |
73 | 74 | function Base.show(io::IO, l::SpectralConv{P}) where {P}
|
74 | 75 | print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
|
75 | 76 | end
|
@@ -156,7 +157,7 @@ function Base.show(io::IO, l::FourierOperator)
|
156 | 157 | "$(l.conv.in_channel) => $(l.conv.out_channel), " *
|
157 | 158 | "$(l.conv.modes), " *
|
158 | 159 | "σ=$(string(l.σ)), " *
|
159 |
| - "permuted=$(l.conv.permuted)" * |
| 160 | + "permuted=$(ispermuted(l.conv))" * |
160 | 161 | ")"
|
161 | 162 | )
|
162 | 163 | end
|
|
0 commit comments