@@ -141,18 +141,36 @@ function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val
141141end
142142
143143# ==================================== unmatricize =======================================
144- function unmatricize (m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2} )
145- return unmatricize (FusionStyle (m), m, axes_dest, invbiperm)
144+ function unmatricize (
145+ m:: AbstractMatrix ,
146+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
147+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
148+ )
149+ return unmatricize (FusionStyle (m), m, codomain_axes, domain_axes)
146150end
151+ # This is the primary function that should be overloaded for new fusion styles.
147152function unmatricize (
148- style:: FusionStyle , m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2}
153+ style:: FusionStyle , m:: AbstractMatrix ,
154+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
155+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
149156 )
150- length (axes_dest) == length (invbiperm) ||
151- throw (ArgumentError (" axes do not match permutation" ))
152- blocked_axes = axes_dest[invbiperm]
153- a12 = unmatricize (style, m, blocked_axes)
154- biperm_dest = biperm (invperm (invbiperm), length_codomain (axes_dest))
155- return permuteblockeddims (a12, biperm_dest)
157+ return throw (
158+ MethodError (
159+ unmatricize,
160+ Tuple{
161+ typeof (style), typeof (m), typeof (codomain_axes), typeof (domain_axes)
162+ },
163+ )
164+ )
165+ end
166+
167+ # Implementation using reshape.
168+ function unmatricize (
169+ style:: ReshapeFusion , m:: AbstractMatrix ,
170+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
171+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
172+ )
173+ return reshape (m, (codomain_axes... , domain_axes... ))
156174end
157175
158176function unmatricize (
@@ -162,38 +180,53 @@ function unmatricize(
162180 return unmatricize (FusionStyle (m), m, blocked_axes)
163181end
164182function unmatricize (
165- :: ReshapeFusion ,
183+ style :: FusionStyle ,
166184 m:: AbstractMatrix ,
167185 blocked_axes:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
168186 )
169- return reshape ( m, Tuple (blocked_axes)... )
187+ return unmatricize (style, m, blocks (blocked_axes)... )
170188end
171189
172190function unmatricize (
173- m:: AbstractMatrix ,
174- codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
175- domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
191+ m:: AbstractMatrix , axes_dest,
192+ invperm1:: Tuple{Vararg{Int}} , invperm2:: Tuple{Vararg{Int}} ,
176193 )
177- return unmatricize (FusionStyle (m), m, codomain_axes, domain_axes )
194+ return unmatricize (FusionStyle (m), m, axes_dest, invperm1, invperm2 )
178195end
179196function unmatricize (
180- style:: FusionStyle , m:: AbstractMatrix ,
181- codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
182- domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
197+ style:: FusionStyle , m:: AbstractMatrix , axes_dest,
198+ invperm1:: Tuple{Vararg{Int}} , invperm2:: Tuple{Vararg{Int}} ,
183199 )
184- blocked_axes = tuplemortar ((codomain_axes, domain_axes))
185- return unmatricize (style, m, blocked_axes)
200+ invbiperm = permmortar ((invperm1, invperm2))
201+ length (axes_dest) == length (invbiperm) ||
202+ throw (ArgumentError (" axes do not match permutation" ))
203+ blocked_axes = axes_dest[invbiperm]
204+ a12 = unmatricize (style, m, blocked_axes)
205+ biperm_dest = biperm (invperm (invbiperm), length_codomain (axes_dest))
206+ return permuteblockeddims (a12, biperm_dest)
207+ end
208+
209+ function unmatricize (m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2} )
210+ return unmatricize (FusionStyle (m), m, axes_dest, invbiperm)
211+ end
212+ function unmatricize (
213+ style:: FusionStyle , m:: AbstractMatrix , axes_dest,
214+ invbiperm:: AbstractBlockPermutation{2}
215+ )
216+ return unmatricize (style, m, axes_dest, blocks (invbiperm)... )
186217end
187218
188219function unmatricize! (
189- a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2}
220+ a_dest:: AbstractArray , m:: AbstractMatrix ,
221+ invperm1:: Tuple{Vararg{Int}} , invperm2:: Tuple{Vararg{Int}} ,
190222 )
191- return unmatricize! (FusionStyle (m), a_dest, m, invbiperm )
223+ return unmatricize! (FusionStyle (m), a_dest, m, invperm1, invperm2 )
192224end
193225function unmatricize! (
194226 style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix ,
195- invbiperm :: AbstractBlockPermutation{2 } ,
227+ invperm1 :: Tuple{Vararg{Int}} , invperm2 :: Tuple{Vararg{Int} } ,
196228 )
229+ invbiperm = permmortar ((invperm1, invperm2))
197230 ndims (a_dest) == length (invbiperm) ||
198231 throw (ArgumentError (" destination does not match permutation" ))
199232 blocked_axes = axes (a_dest)[invbiperm]
@@ -202,16 +235,48 @@ function unmatricize!(
202235 return permuteblockeddims! (a_dest, a_perm, biperm_dest)
203236end
204237
238+ function unmatricize! (
239+ a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2}
240+ )
241+ return unmatricize! (FusionStyle (m), a_dest, m, invbiperm)
242+ end
243+ function unmatricize! (
244+ style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix ,
245+ invbiperm:: AbstractBlockPermutation{2} ,
246+ )
247+ return unmatricize! (style, a_dest, m, blocks (invbiperm)... )
248+ end
249+
205250function unmatricizeadd! (
206- a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm:: AbstractBlockPermutation{2} ,
251+ a_dest:: AbstractArray , m:: AbstractMatrix ,
252+ invperm1:: Tuple{Vararg{Int}} , invperm2:: Tuple{Vararg{Int}} ,
207253 α:: Number , β:: Number
208254 )
209- return unmatricizeadd! (FusionStyle (a_dest), a_dest, m, invbiperm , α, β)
255+ return unmatricizeadd! (FusionStyle (a_dest), a_dest, m, invperm1, invperm2 , α, β)
210256end
211257function unmatricizeadd! (
212- style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix , invbiperm, α, β
258+ style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix ,
259+ invperm1:: Tuple{Vararg{Int}} , invperm2:: Tuple{Vararg{Int}} ,
260+ α:: Number , β:: Number ,
213261 )
214- a12 = unmatricize (style, m, axes (a_dest), invbiperm )
262+ a12 = unmatricize (style, m, axes (a_dest), invperm1, invperm2 )
215263 a_dest .= α .* a12 .+ β .* a_dest
216264 return a_dest
217265end
266+
267+ function unmatricizeadd! (
268+ a_dest:: AbstractArray , m:: AbstractMatrix ,
269+ invbiperm:: AbstractBlockPermutation{2} ,
270+ α:: Number , β:: Number
271+ )
272+ return unmatricizeadd! (FusionStyle (a_dest), a_dest, m, invbiperm, α, β)
273+ end
274+ function unmatricizeadd! (
275+ style:: FusionStyle , a_dest:: AbstractArray , m:: AbstractMatrix ,
276+ invbiperm:: AbstractBlockPermutation{2} ,
277+ α:: Number , β:: Number ,
278+ )
279+ return unmatricizeadd! (
280+ style, a_dest, m, blocks (invbiperm)... , α, β
281+ )
282+ end
0 commit comments