Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 9fd713b

Browse files
authored
Merge pull request #18 from yuehhua/permuted
Improve permuted
2 parents e0215fa + 93e1b64 commit 9fd713b

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/fourier.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,21 @@ export
33
FourierOperator
44

55
struct SpectralConv{P, N, T, S, F}
6-
permuted::Bool
76
weight::T
87
in_channel::S
98
out_channel::S
109
modes::NTuple{N, S}
1110
σ::F
1211
end
1312

14-
function SpectralConv(
15-
permuted::Bool,
13+
function SpectralConv{P}(
1614
weight::T,
1715
in_channel::S,
1816
out_channel::S,
1917
modes::NTuple{N, S},
2018
σ::F
21-
) where {N, T, S, F}
22-
return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ)
19+
) where {P, N, T, S, F}
20+
return SpectralConv{P, N, T, S, F}(weight, in_channel, out_channel, modes, σ)
2321
end
2422

2523
"""
@@ -63,13 +61,16 @@ function SpectralConv(
6361
scale = one(T) / (in_chs * out_chs)
6462
weights = scale * init(out_chs, in_chs, prod(modes))
6563

66-
return SpectralConv(permuted, weights, in_chs, out_chs, modes, σ)
64+
return SpectralConv{permuted}(weights, in_chs, out_chs, modes, σ)
6765
end
6866

69-
Flux.@functor SpectralConv
67+
Flux.@functor SpectralConv{true}
68+
Flux.@functor SpectralConv{false}
7069

7170
Base.ndims(::SpectralConv{P, N}) where {P, N} = N
7271

72+
ispermuted(::SpectralConv{P}) where {P} = P
73+
7374
function Base.show(io::IO, l::SpectralConv{P}) where {P}
7475
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
7576
end
@@ -156,7 +157,7 @@ function Base.show(io::IO, l::FourierOperator)
156157
"$(l.conv.in_channel) => $(l.conv.out_channel), " *
157158
"$(l.conv.modes), " *
158159
"σ=$(string(l.σ)), " *
159-
"permuted=$(l.conv.permuted)" *
160+
"permuted=$(ispermuted(l.conv))" *
160161
")"
161162
)
162163
end

0 commit comments

Comments
 (0)