Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit e309f0f

Browse files
committed
refactor permuted
1 parent e0215fa commit e309f0f

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/fourier.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,22 @@ export
33
FourierOperator
44

55
struct SpectralConv{P, N, T, S, F}
6-
permuted::Bool
76
weight::T
87
in_channel::S
98
out_channel::S
109
modes::NTuple{N, S}
1110
σ::F
12-
end
1311

14-
function SpectralConv(
15-
permuted::Bool,
16-
weight::T,
17-
in_channel::S,
18-
out_channel::S,
19-
modes::NTuple{N, S},
20-
σ::F
21-
) where {N, T, S, F}
22-
return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ)
12+
function SpectralConv(
13+
permuted::Bool,
14+
weight::T,
15+
in_channel::S,
16+
out_channel::S,
17+
modes::NTuple{N, S},
18+
σ::F
19+
) where {N, T, S, F}
20+
return new{permuted, N, T, S, F}(weight, in_channel, out_channel, modes, σ)
21+
end
2322
end
2423

2524
"""
@@ -70,6 +69,8 @@ Flux.@functor SpectralConv
7069

7170
Base.ndims(::SpectralConv{P, N}) where {P, N} = N
7271

72+
ispermuted(::SpectralConv{P}) where {P} = P
73+
7374
function Base.show(io::IO, l::SpectralConv{P}) where {P}
7475
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
7576
end
@@ -156,7 +157,7 @@ function Base.show(io::IO, l::FourierOperator)
156157
"$(l.conv.in_channel) => $(l.conv.out_channel), " *
157158
"$(l.conv.modes), " *
158159
"σ=$(string(l.σ)), " *
159-
"permuted=$(l.conv.permuted)" *
160+
"permuted=$(ispermuted(l.conv))" *
160161
")"
161162
)
162163
end

0 commit comments

Comments
 (0)