|
3 | 3 | SpectralConv,
|
4 | 4 | OperatorKernel
|
5 | 5 |
|
6 |
| -struct OperatorConv{P, N, T, S, TT} |
| 6 | +struct OperatorConv{P, T, S, TT} |
7 | 7 | weight::T
|
8 | 8 | in_channel::S
|
9 | 9 | out_channel::S
|
10 |
| - modes::NTuple{N, S} |
11 | 10 | transform::TT
|
12 | 11 | end
|
13 | 12 |
|
14 | 13 | function OperatorConv{P}(
|
15 | 14 | weight::T,
|
16 | 15 | in_channel::S,
|
17 | 16 | out_channel::S,
|
18 |
| - modes::NTuple{N, S}, |
19 | 17 | transform::TT
|
20 |
| -) where {P, N, T, S, TT<:AbstractTransform} |
21 |
| - return OperatorConv{P, N, T, S, TT}(weight, in_channel, out_channel, modes, transform) |
| 18 | +) where {P, T, S, TT<:AbstractTransform} |
| 19 | + return OperatorConv{P, T, S, TT}(weight, in_channel, out_channel, transform) |
22 | 20 | end
|
23 | 21 |
|
24 | 22 | """
|
@@ -58,7 +56,7 @@ function OperatorConv(
|
58 | 56 | weights = scale * init(prod(modes), in_chs, out_chs)
|
59 | 57 | transform = Transform(modes)
|
60 | 58 |
|
61 |
| - return OperatorConv{permuted}(weights, in_chs, out_chs, modes, transform) |
| 59 | + return OperatorConv{permuted}(weights, in_chs, out_chs, transform) |
62 | 60 | end
|
63 | 61 |
|
64 | 62 | function SpectralConv(
|
|
74 | 72 | Flux.@functor OperatorConv{true}
|
75 | 73 | Flux.@functor OperatorConv{false}
|
76 | 74 |
|
77 |
| -Base.ndims(::OperatorConv{P, N}) where {P, N} = N |
| 75 | +Base.ndims(oc::OperatorConv) = ndims(oc.transform) |
78 | 76 |
|
79 | 77 | ispermuted(::OperatorConv{P}) where {P} = P
|
80 | 78 |
|
81 | 79 | function Base.show(io::IO, l::OperatorConv{P}) where {P}
|
82 |
| - print(io, "OperatorConv($(l.in_channel) => $(l.out_channel), $(l.modes), permuted=$P)") |
| 80 | + print(io, "OperatorConv($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)") |
83 | 81 | end
|
84 | 82 |
|
85 | 83 | function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
|
86 |
| - # ft = FourierTransform(m.modes) |
87 |
| - |
88 | 84 | 𝐱_transformed = transform(m.transform, 𝐱) # [size(x)..., in_chs, batch]
|
89 | 85 | 𝐱_truncated = truncate_modes(m.transform, 𝐱_transformed) # [modes..., in_chs, batch]
|
90 | 86 | 𝐱_applied_pattern = apply_pattern(𝐱_truncated, m.weight) # [modes..., out_chs, batch]
|
@@ -162,7 +158,7 @@ function Base.show(io::IO, l::OperatorKernel)
|
162 | 158 | io,
|
163 | 159 | "OperatorKernel(" *
|
164 | 160 | "$(l.conv.in_channel) => $(l.conv.out_channel), " *
|
165 |
| - "$(l.conv.modes), " * |
| 161 | + "$(l.conv.transform.modes), " * |
166 | 162 | "$(nameof(typeof(l.conv.transform))), " *
|
167 | 163 | "σ=$(string(l.σ)), " *
|
168 | 164 | "permuted=$(ispermuted(l.conv))" *
|
|
0 commit comments