Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 4beb3f8

Browse files
committed
rm activation field from SpectralConv
1 parent 04f2643 commit 4beb3f8

File tree

2 files changed

+13
-22
lines changed

2 files changed

+13
-22
lines changed

src/fourier.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,48 @@ export
22
SpectralConv,
33
FourierOperator
44

5-
struct SpectralConv{P, N, T, S, F}
5+
struct SpectralConv{P, N, T, S}
66
weight::T
77
in_channel::S
88
out_channel::S
99
modes::NTuple{N, S}
10-
σ::F
1110
end
1211

1312
function SpectralConv{P}(
1413
weight::T,
1514
in_channel::S,
1615
out_channel::S,
1716
modes::NTuple{N, S},
18-
σ::F
19-
) where {P, N, T, S, F}
20-
return SpectralConv{P, N, T, S, F}(weight, in_channel, out_channel, modes, σ)
17+
) where {P, N, T, S}
18+
return SpectralConv{P, N, T, S}(weight, in_channel, out_channel, modes)
2119
end
2220

2321
"""
2422
SpectralConv(
25-
ch, modes, σ=identity;
23+
ch, modes;
2624
init=c_glorot_uniform, permuted=false, T=ComplexF32
2725
)
2826
2927
## Arguments
3028
3129
* `ch`: Input and output channel size, e.g. `64=>64`.
3230
* `modes`: The Fourier modes to be preserved.
33-
* `σ`: Activation function.
3431
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
3532
data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
3633
3734
## Example
3835
3936
```jldoctest
4037
julia> SpectralConv(2=>5, (16, ))
41-
SpectralConv(2 => 5, (16,), σ=identity, permuted=false)
42-
43-
julia> using Flux
44-
45-
julia> SpectralConv(2=>5, (16, ), relu)
46-
SpectralConv(2 => 5, (16,), σ=relu, permuted=false)
38+
SpectralConv(2 => 5, (16,), permuted=false)
4739
48-
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
49-
SpectralConv(2 => 5, (16,), σ=relu, permuted=true)
40+
julia> SpectralConv(2=>5, (16, ), permuted=true)
41+
SpectralConv(2 => 5, (16,), permuted=true)
5042
```
5143
"""
5244
function SpectralConv(
5345
ch::Pair{S, S},
54-
modes::NTuple{N, S},
55-
σ=identity;
46+
modes::NTuple{N, S};
5647
init=c_glorot_uniform,
5748
permuted=false,
5849
T::DataType=ComplexF32
@@ -61,7 +52,7 @@ function SpectralConv(
6152
scale = one(T) / (in_chs * out_chs)
6253
weights = scale * init(prod(modes), in_chs, out_chs)
6354

64-
return SpectralConv{permuted}(weights, in_chs, out_chs, modes, σ)
55+
return SpectralConv{permuted}(weights, in_chs, out_chs, modes)
6556
end
6657

6758
Flux.@functor SpectralConv{true}
@@ -72,7 +63,7 @@ Base.ndims(::SpectralConv{P, N}) where {P, N} = N
7263
ispermuted(::SpectralConv{P}) where {P} = P
7364

7465
function Base.show(io::IO, l::SpectralConv{P}) where {P}
75-
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
66+
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), permuted=$P)")
7667
end
7768

7869
function spectral_conv(m::SpectralConv, 𝐱::AbstractArray)
@@ -85,7 +76,7 @@ function spectral_conv(m::SpectralConv, 𝐱::AbstractArray)
8576
𝐱_padded = spectral_pad(𝐱_shaped, (size(𝐱_fft)[1:end-2]..., size(𝐱_weighted, 2), size(𝐱_weighted, 3))) # [x, out_chs, batch] <- [modes, out_chs, batch]
8677
𝐱_ifft = real(ifft(𝐱_padded, 1:ndims(m))) # [x, out_chs, batch]
8778

88-
return m.σ.(𝐱_ifft)
79+
return 𝐱_ifft
8980
end
9081

9182
function (m::SpectralConv{false})(𝐱)

test/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SpectralConv(ch, modes)
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
10-
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=false)"
10+
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), permuted=false)"
1111

1212
𝐱 = rand(Float32, 2, 1024, 5)
1313
@test size(m(𝐱)) == (128, 1024, 5)
@@ -26,7 +26,7 @@ end
2626
SpectralConv(ch, modes, permuted=true)
2727
)
2828
@test ndims(SpectralConv(ch, modes, permuted=true)) == 1
29-
@test repr(SpectralConv(ch, modes, permuted=true)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=true)"
29+
@test repr(SpectralConv(ch, modes, permuted=true)) == "SpectralConv(64 => 128, (16,), permuted=true)"
3030

3131
𝐱 = rand(Float32, 2, 1024, 5)
3232
𝐱 = permutedims(𝐱, (2, 1, 3))

0 commit comments

Comments
 (0)