Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit b20ab6a

Browse files
committed
revise show
1 parent 9cc4e81 commit b20ab6a

File tree

2 files changed

+18
-30
lines changed

2 files changed

+18
-30
lines changed

src/fourier.jl

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ end
2828
2929
```jldoctest
3030
julia> SpectralConv(2=>5, (16, ))
31-
SpectralConv(2 => 5, (16,), σ=identity)
31+
SpectralConv(2 => 5, (16,), σ=identity, permuted=false)
3232
3333
julia> using Flux
3434
3535
julia> SpectralConv(2=>5, (16, ), relu)
36-
SpectralConv(2 => 5, (16,), σ=relu)
36+
SpectralConv(2 => 5, (16,), σ=relu, permuted=false)
3737
3838
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
39-
SpectralConv(2 => 5, (16,), σ=relu)
39+
SpectralConv(2 => 5, (16,), σ=relu, permuted=true)
4040
```
4141
"""
4242
function SpectralConv(
@@ -60,8 +60,10 @@ Flux.@functor SpectralConv
6060

6161
Base.ndims(::SpectralConv{P,N}) where {P,N} = N
6262

63-
function Base.show(io::IO, l::SpectralConv)
64-
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)))")
63+
permuted(::SpectralConv{P}) where {P} = P
64+
65+
function Base.show(io::IO, l::SpectralConv{P}) where {P}
66+
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
6567
end
6668

6769
function spectral_conv(m::SpectralConv, 𝐱::AbstractArray)
@@ -114,36 +116,15 @@ end
114116
115117
```jldoctest
116118
julia> FourierOperator(2=>5, (16, ))
117-
Chain(
118-
Parallel(
119-
+,
120-
Dense(2, 5), # 15 parameters
121-
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
122-
),
123-
NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
124-
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
119+
FourierOperator(2 => 5, (16,), σ=identity, permuted=false)
125120
126121
julia> using Flux
127122
128123
julia> FourierOperator(2=>5, (16, ), relu)
129-
Chain(
130-
Parallel(
131-
+,
132-
Dense(2, 5), # 15 parameters
133-
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
134-
),
135-
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
136-
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
124+
FourierOperator(2 => 5, (16,), σ=relu, permuted=false)
137125
138126
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
139-
Chain(
140-
Parallel(
141-
+,
142-
Conv((1,), 2 => 5), # 15 parameters
143-
SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
144-
),
145-
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
146-
) # Total: 3 arrays, 175 parameters, 1.871 KiB.
127+
FourierOperator(2 => 5, (16,), σ=relu, permuted=true)
147128
```
148129
"""
149130
function FourierOperator(
@@ -160,6 +141,10 @@ end
160141

161142
Flux.@functor FourierOperator
162143

144+
function Base.show(io::IO, l::FourierOperator)
145+
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(permuted(l.conv)))")
146+
end
147+
163148
function (m::FourierOperator)(𝐱)
164149
return m.σ.(m.linear(𝐱) + m.conv(𝐱))
165150
end

test/fourier.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
SpectralConv(ch, modes)
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
10-
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity)"
10+
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=false)"
1111

1212
𝐱, _ = get_burgers_data(n=5)
1313
@test size(m(𝐱)) == (128, 1024, 5)
@@ -26,6 +26,7 @@ end
2626
SpectralConv(ch, modes, permuted=true)
2727
)
2828
@test ndims(SpectralConv(ch, modes, permuted=true)) == 1
29+
@test repr(SpectralConv(ch, modes, permuted=true)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=true)"
2930

3031
𝐱, _ = get_burgers_data(n=5)
3132
𝐱 = permutedims(𝐱, (2, 1, 3))
@@ -44,6 +45,7 @@ end
4445
Dense(2, 64),
4546
FourierOperator(ch, modes)
4647
)
48+
@test repr(FourierOperator(ch, modes)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=false)"
4749

4850
𝐱, _ = get_burgers_data(n=5)
4951
@test size(m(𝐱)) == (128, 1024, 5)
@@ -61,6 +63,7 @@ end
6163
Conv((1, ), 2=>64),
6264
FourierOperator(ch, modes, permuted=true)
6365
)
66+
@test repr(FourierOperator(ch, modes, permuted=true)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=true)"
6467

6568
𝐱, _ = get_burgers_data(n=5)
6669
𝐱 = permutedims(𝐱, (2, 1, 3))

0 commit comments

Comments
 (0)