Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit e764d5b

Browse files
committed
Update show for OperatorKernel
1 parent 0185b1f commit e764d5b

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/fourier.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,15 @@ end
131131
132132
```jldoctest
133133
julia> OperatorKernel(2=>5, (16, ), FourierTransform)
134-
OperatorKernel(2 => 5, (16,), σ=identity, permuted=false)
134+
OperatorKernel(2 => 5, (16,), FourierTransform, σ=identity, permuted=false)
135135
136136
julia> using Flux
137137
138138
julia> OperatorKernel(2=>5, (16, ), FourierTransform, relu)
139-
OperatorKernel(2 => 5, (16,), σ=relu, permuted=false)
139+
OperatorKernel(2 => 5, (16,), FourierTransform, σ=relu, permuted=false)
140140
141141
julia> OperatorKernel(2=>5, (16, ), FourierTransform, relu, permuted=true)
142-
OperatorKernel(2 => 5, (16,), σ=relu, permuted=true)
142+
OperatorKernel(2 => 5, (16,), FourierTransform, σ=relu, permuted=true)
143143
```
144144
"""
145145
function OperatorKernel(
@@ -163,6 +163,7 @@ function Base.show(io::IO, l::OperatorKernel)
163163
"OperatorKernel(" *
164164
"$(l.conv.in_channel) => $(l.conv.out_channel), " *
165165
"$(l.conv.modes), " *
166+
"$(nameof(typeof(l.conv.transform))), " *
166167
"σ=$(string(l.σ)), " *
167168
"permuted=$(ispermuted(l.conv))" *
168169
")"
@@ -185,9 +186,9 @@ einsum(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[m
185186
function apply_pattern(𝐱_truncated, 𝐰)
186187
x_size = size(𝐱_truncated) # [m.modes..., in_chs, batch]
187188

188-
𝐱_flattened = reshape(𝐱_truncated, :, x_size[end-1:end]...) # [prod(m.modes), out_chs, batch], only 3-dims
189+
𝐱_flattened = reshape(𝐱_truncated, :, x_size[end-1:end]...) # [prod(m.modes), in_chs, batch], only 3-dims
189190
𝐱_weighted = einsum(𝐱_flattened, 𝐰) # [prod(m.modes), out_chs, batch], only 3-dims
190-
𝐱_shaped = reshape(𝐱_weighted, x_size[1:end-2]..., size(𝐱_weighted, 2), size(𝐱_weighted, 3)) # [m.modes..., out_chs, batch]
191+
𝐱_shaped = reshape(𝐱_weighted, x_size[1:end-2]..., size(𝐱_weighted)[2:3]...) # [m.modes..., out_chs, batch]
191192

192193
return 𝐱_shaped
193194
end

test/fourier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
Dense(2, 64),
4646
OperatorKernel(ch, modes, FourierTransform)
4747
)
48-
@test repr(OperatorKernel(ch, modes, FourierTransform)) == "OperatorKernel(64 => 128, (16,), σ=identity, permuted=false)"
48+
@test repr(OperatorKernel(ch, modes, FourierTransform)) == "OperatorKernel(64 => 128, (16,), FourierTransform, σ=identity, permuted=false)"
4949

5050
𝐱 = rand(Float32, 2, 1024, 5)
5151
@test size(m(𝐱)) == (128, 1024, 5)
@@ -63,7 +63,7 @@ end
6363
Conv((1, ), 2=>64),
6464
OperatorKernel(ch, modes, FourierTransform, permuted=true)
6565
)
66-
@test repr(OperatorKernel(ch, modes, FourierTransform, permuted=true)) == "OperatorKernel(64 => 128, (16,), σ=identity, permuted=true)"
66+
@test repr(OperatorKernel(ch, modes, FourierTransform, permuted=true)) == "OperatorKernel(64 => 128, (16,), FourierTransform, σ=identity, permuted=true)"
6767

6868
𝐱 = rand(Float32, 2, 1024, 5)
6969
𝐱 = permutedims(𝐱, (2, 1, 3))

0 commit comments

Comments
 (0)