Skip to content

Commit d2c7769

Browse files
committed
Refactor unmatricize
1 parent 64191ab commit d2c7769

File tree

1 file changed

+92
-27
lines changed

1 file changed

+92
-27
lines changed

src/matricize.jl

Lines changed: 92 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,36 @@ function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val
141141
end
142142

143143
# ==================================== unmatricize =======================================
144-
function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2})
145-
return unmatricize(FusionStyle(m), m, axes_dest, invbiperm)
144+
function unmatricize(
145+
m::AbstractMatrix,
146+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
147+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
148+
)
149+
return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes)
146150
end
151+
# This is the primary function that should be overloaded for new fusion styles.
147152
function unmatricize(
148-
style::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}
153+
style::FusionStyle, m::AbstractMatrix,
154+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
155+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
149156
)
150-
length(axes_dest) == length(invbiperm) ||
151-
throw(ArgumentError("axes do not match permutation"))
152-
blocked_axes = axes_dest[invbiperm]
153-
a12 = unmatricize(style, m, blocked_axes)
154-
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest))
155-
return permuteblockeddims(a12, biperm_dest)
157+
return throw(
158+
MethodError(
159+
unmatricize,
160+
Tuple{
161+
typeof(style), typeof(m), typeof(codomain_axes), typeof(domain_axes)
162+
},
163+
)
164+
)
165+
end
166+
167+
# Implementation using reshape.
168+
function unmatricize(
169+
style::ReshapeFusion, m::AbstractMatrix,
170+
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
171+
domain_axes::Tuple{Vararg{AbstractUnitRange}},
172+
)
173+
return reshape(m, (codomain_axes..., domain_axes...))
156174
end
157175

158176
function unmatricize(
@@ -162,38 +180,53 @@ function unmatricize(
162180
return unmatricize(FusionStyle(m), m, blocked_axes)
163181
end
164182
function unmatricize(
165-
::ReshapeFusion,
183+
style::FusionStyle,
166184
m::AbstractMatrix,
167185
blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}},
168186
)
169-
return reshape(m, Tuple(blocked_axes)...)
187+
return unmatricize(style, m, blocks(blocked_axes)...)
170188
end
171189

172190
function unmatricize(
173-
m::AbstractMatrix,
174-
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
175-
domain_axes::Tuple{Vararg{AbstractUnitRange}},
191+
m::AbstractMatrix, axes_dest,
192+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
176193
)
177-
return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes)
194+
return unmatricize(FusionStyle(m), m, axes_dest, invperm1, invperm2)
178195
end
179196
function unmatricize(
180-
style::FusionStyle, m::AbstractMatrix,
181-
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
182-
domain_axes::Tuple{Vararg{AbstractUnitRange}},
197+
style::FusionStyle, m::AbstractMatrix, axes_dest,
198+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
183199
)
184-
blocked_axes = tuplemortar((codomain_axes, domain_axes))
185-
return unmatricize(style, m, blocked_axes)
200+
invbiperm = permmortar((invperm1, invperm2))
201+
length(axes_dest) == length(invbiperm) ||
202+
throw(ArgumentError("axes do not match permutation"))
203+
blocked_axes = axes_dest[invbiperm]
204+
a12 = unmatricize(style, m, blocked_axes)
205+
biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest))
206+
return permuteblockeddims(a12, biperm_dest)
207+
end
208+
209+
function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2})
210+
return unmatricize(FusionStyle(m), m, axes_dest, invbiperm)
211+
end
212+
function unmatricize(
213+
style::FusionStyle, m::AbstractMatrix, axes_dest,
214+
invbiperm::AbstractBlockPermutation{2}
215+
)
216+
return unmatricize(style, m, axes_dest, blocks(invbiperm)...)
186217
end
187218

188219
function unmatricize!(
189-
a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}
220+
a_dest::AbstractArray, m::AbstractMatrix,
221+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
190222
)
191-
return unmatricize!(FusionStyle(m), a_dest, m, invbiperm)
223+
return unmatricize!(FusionStyle(m), a_dest, m, invperm1, invperm2)
192224
end
193225
function unmatricize!(
194226
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix,
195-
invbiperm::AbstractBlockPermutation{2},
227+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
196228
)
229+
invbiperm = permmortar((invperm1, invperm2))
197230
ndims(a_dest) == length(invbiperm) ||
198231
throw(ArgumentError("destination does not match permutation"))
199232
blocked_axes = axes(a_dest)[invbiperm]
@@ -202,16 +235,48 @@ function unmatricize!(
202235
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
203236
end
204237

238+
function unmatricize!(
239+
a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}
240+
)
241+
return unmatricize!(FusionStyle(m), a_dest, m, invbiperm)
242+
end
243+
function unmatricize!(
244+
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix,
245+
invbiperm::AbstractBlockPermutation{2},
246+
)
247+
return unmatricize!(style, a_dest, m, blocks(invbiperm)...)
248+
end
249+
205250
function unmatricizeadd!(
206-
a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2},
251+
a_dest::AbstractArray, m::AbstractMatrix,
252+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
207253
α::Number, β::Number
208254
)
209-
return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β)
255+
return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invperm1, invperm2, α, β)
210256
end
211257
function unmatricizeadd!(
212-
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, invbiperm, α, β
258+
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix,
259+
invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}},
260+
α::Number, β::Number,
213261
)
214-
a12 = unmatricize(style, m, axes(a_dest), invbiperm)
262+
a12 = unmatricize(style, m, axes(a_dest), invperm1, invperm2)
215263
a_dest .= α .* a12 .+ β .* a_dest
216264
return a_dest
217265
end
266+
267+
function unmatricizeadd!(
268+
a_dest::AbstractArray, m::AbstractMatrix,
269+
invbiperm::AbstractBlockPermutation{2},
270+
α::Number, β::Number
271+
)
272+
return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β)
273+
end
274+
function unmatricizeadd!(
275+
style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix,
276+
invbiperm::AbstractBlockPermutation{2},
277+
α::Number, β::Number,
278+
)
279+
return unmatricizeadd!(
280+
style, a_dest, m, blocks(invbiperm)..., α, β
281+
)
282+
end

0 commit comments

Comments
 (0)