Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d707acf

Browse files
committed
Remove modes from OperatorConv
1 parent e764d5b commit d707acf

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

src/fourier.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@ export
33
SpectralConv,
44
OperatorKernel
55

6-
struct OperatorConv{P, N, T, S, TT}
6+
struct OperatorConv{P, T, S, TT}
77
weight::T
88
in_channel::S
99
out_channel::S
10-
modes::NTuple{N, S}
1110
transform::TT
1211
end
1312

1413
function OperatorConv{P}(
1514
weight::T,
1615
in_channel::S,
1716
out_channel::S,
18-
modes::NTuple{N, S},
1917
transform::TT
20-
) where {P, N, T, S, TT<:AbstractTransform}
21-
return OperatorConv{P, N, T, S, TT}(weight, in_channel, out_channel, modes, transform)
18+
) where {P, T, S, TT<:AbstractTransform}
19+
return OperatorConv{P, T, S, TT}(weight, in_channel, out_channel, transform)
2220
end
2321

2422
"""
@@ -58,7 +56,7 @@ function OperatorConv(
5856
weights = scale * init(prod(modes), in_chs, out_chs)
5957
transform = Transform(modes)
6058

61-
return OperatorConv{permuted}(weights, in_chs, out_chs, modes, transform)
59+
return OperatorConv{permuted}(weights, in_chs, out_chs, transform)
6260
end
6361

6462
function SpectralConv(
@@ -74,17 +72,15 @@ end
7472
Flux.@functor OperatorConv{true}
7573
Flux.@functor OperatorConv{false}
7674

77-
Base.ndims(::OperatorConv{P, N}) where {P, N} = N
75+
Base.ndims(oc::OperatorConv) = ndims(oc.transform)
7876

7977
ispermuted(::OperatorConv{P}) where {P} = P
8078

8179
function Base.show(io::IO, l::OperatorConv{P}) where {P}
82-
print(io, "OperatorConv($(l.in_channel) => $(l.out_channel), $(l.modes), permuted=$P)")
80+
print(io, "OperatorConv($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)")
8381
end
8482

8583
function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
86-
# ft = FourierTransform(m.modes)
87-
8884
𝐱_transformed = transform(m.transform, 𝐱) # [size(x)..., in_chs, batch]
8985
𝐱_truncated = truncate_modes(m.transform, 𝐱_transformed) # [modes..., in_chs, batch]
9086
𝐱_applied_pattern = apply_pattern(𝐱_truncated, m.weight) # [modes..., out_chs, batch]
@@ -162,7 +158,7 @@ function Base.show(io::IO, l::OperatorKernel)
162158
io,
163159
"OperatorKernel(" *
164160
"$(l.conv.in_channel) => $(l.conv.out_channel), " *
165-
"$(l.conv.modes), " *
161+
"$(l.conv.transform.modes), " *
166162
"$(nameof(typeof(l.conv.transform))), " *
167163
"σ=$(string(l.σ)), " *
168164
"permuted=$(ispermuted(l.conv))" *

test/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
OperatorConv(ch, modes, FourierTransform)
88
)
99
@test ndims(OperatorConv(ch, modes, FourierTransform)) == 1
10-
@test repr(OperatorConv(ch, modes, FourierTransform)) == "OperatorConv(64 => 128, (16,), permuted=false)"
10+
@test repr(OperatorConv(ch, modes, FourierTransform)) == "OperatorConv(64 => 128, (16,), FourierTransform, permuted=false)"
1111

1212
𝐱 = rand(Float32, 2, 1024, 5)
1313
@test size(m(𝐱)) == (128, 1024, 5)
@@ -26,7 +26,7 @@ end
2626
OperatorConv(ch, modes, FourierTransform, permuted=true)
2727
)
2828
@test ndims(OperatorConv(ch, modes, FourierTransform, permuted=true)) == 1
29-
@test repr(OperatorConv(ch, modes, FourierTransform, permuted=true)) == "OperatorConv(64 => 128, (16,), permuted=true)"
29+
@test repr(OperatorConv(ch, modes, FourierTransform, permuted=true)) == "OperatorConv(64 => 128, (16,), FourierTransform, permuted=true)"
3030

3131
𝐱 = rand(Float32, 2, 1024, 5)
3232
𝐱 = permutedims(𝐱, (2, 1, 3))

0 commit comments

Comments
 (0)