@@ -131,15 +131,15 @@ end
131
131
132
132
```jldoctest
133
133
julia> OperatorKernel(2=>5, (16, ), FourierTransform)
134
- OperatorKernel(2 => 5, (16,), σ=identity, permuted=false)
134
+ OperatorKernel(2 => 5, (16,), FourierTransform, σ=identity, permuted=false)
135
135
136
136
julia> using Flux
137
137
138
138
julia> OperatorKernel(2=>5, (16, ), FourierTransform, relu)
139
- OperatorKernel(2 => 5, (16,), σ=relu, permuted=false)
139
+ OperatorKernel(2 => 5, (16,), FourierTransform, σ=relu, permuted=false)
140
140
141
141
julia> OperatorKernel(2=>5, (16, ), FourierTransform, relu, permuted=true)
142
- OperatorKernel(2 => 5, (16,), σ=relu, permuted=true)
142
+ OperatorKernel(2 => 5, (16,), FourierTransform, σ=relu, permuted=true)
143
143
```
144
144
"""
145
145
function OperatorKernel (
@@ -163,6 +163,7 @@ function Base.show(io::IO, l::OperatorKernel)
163
163
" OperatorKernel(" *
164
164
" $(l. conv. in_channel) => $(l. conv. out_channel) , " *
165
165
" $(l. conv. modes) , " *
166
+ " $(nameof (typeof (l. conv. transform))) , " *
166
167
" σ=$(string (l. σ)) , " *
167
168
" permuted=$(ispermuted (l. conv)) " *
168
169
" )"
@@ -185,9 +186,9 @@ einsum(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[m
185
186
function apply_pattern (𝐱_truncated, 𝐰)
186
187
x_size = size (𝐱_truncated) # [m.modes..., in_chs, batch]
187
188
188
- 𝐱_flattened = reshape (𝐱_truncated, :, x_size[end - 1 : end ]. .. ) # [prod(m.modes), out_chs , batch], only 3-dims
189
+ 𝐱_flattened = reshape (𝐱_truncated, :, x_size[end - 1 : end ]. .. ) # [prod(m.modes), in_chs , batch], only 3-dims
189
190
𝐱_weighted = einsum (𝐱_flattened, 𝐰) # [prod(m.modes), out_chs, batch], only 3-dims
190
- 𝐱_shaped = reshape (𝐱_weighted, x_size[1 : end - 2 ]. .. , size (𝐱_weighted, 2 ), size (𝐱_weighted, 3 ) ) # [m.modes..., out_chs, batch]
191
+ 𝐱_shaped = reshape (𝐱_weighted, x_size[1 : end - 2 ]. .. , size (𝐱_weighted)[ 2 : 3 ] . .. ) # [m.modes..., out_chs, batch]
191
192
192
193
return 𝐱_shaped
193
194
end
0 commit comments