4545# matrix factorizations assume copy
4646# maybe: copy=false kwarg
4747
48- function matricize (a:: AbstractArray , biperm:: AbstractBlockPermutation{2} ; copy = false )
48+ function matricize (a:: AbstractArray , biperm:: AbstractBlockPermutation{2} )
4949 ndims (a) == length (biperm) || throw (ArgumentError (" Invalid bipermutation" ))
50- return matricize (FusionStyle (a), a, biperm; copy )
50+ return matricize (FusionStyle (a), a, biperm)
5151end
5252
5353function matricize (
54- style:: FusionStyle , a:: AbstractArray , biperm:: AbstractBlockPermutation{2} ; copy = false
54+ style:: FusionStyle , a:: AbstractArray , biperm:: AbstractBlockPermutation{2}
5555)
56- if istrivialperm (Tuple (biperm)) && ! copy
57- return matricize (style, a, trivialperm (biperm))
58- end
5956 a_perm = permuteblockeddims (a, biperm)
6057 return matricize (style, a_perm, trivialperm (biperm))
6158end
6259
6360function matricize (
64- style:: FusionStyle , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2} ; copy = false
61+ style:: FusionStyle , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
6562)
6663 return throw (MethodError (matricize, Tuple{typeof (style),typeof (a),typeof (biperm)}))
6764end
@@ -72,22 +69,29 @@ function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPerm
7269 return reshape (a, new_axes... )
7370end
7471
75- function matricize (a:: AbstractArray , permblock1:: Tuple , permblock2:: Tuple ; copy = false )
76- return matricize (a, blockedpermvcat (permblock1, permblock2; length= Val (ndims (a))); copy )
72+ function matricize (a:: AbstractArray , permblock1:: Tuple , permblock2:: Tuple )
73+ return matricize (a, blockedpermvcat (permblock1, permblock2; length= Val (ndims (a))))
7774end
7875
7976# ==================================== unmatricize =======================================
80- function unmatricize (m:: AbstractMatrix , axes, biperm:: AbstractBlockPermutation{2} )
81- length (axes) == length (biperm) || throw (ArgumentError (" axes do not match permutation" ))
82- return unmatricize (FusionStyle (m), m, axes, biperm)
77+ function unmatricize (
78+ m:: AbstractMatrix , axes_dest, biperm_dest_to_a12:: AbstractBlockPermutation{2}
79+ )
80+ length (axes_dest) == length (biperm_dest_to_a12) ||
81+ throw (ArgumentError (" axes do not match permutation" ))
82+ return unmatricize (FusionStyle (m), m, axes_dest, biperm_dest_to_a12)
8383end
8484
8585function unmatricize (
86- :: FusionStyle , m:: AbstractMatrix , axes, biperm:: AbstractBlockPermutation{2}
86+ :: FusionStyle ,
87+ m:: AbstractMatrix ,
88+ axes_dest,
89+ biperm_dest_to_a12:: AbstractBlockPermutation{2} ,
8790)
88- blocked_axes = axes[biperm]
89- a_perm = unmatricize (m, blocked_axes)
90- return permuteblockeddims (a_perm, invperm (biperm))
91+ blocked_axes = axes_dest[biperm_dest_to_a12]
92+ a12 = unmatricize (m, blocked_axes)
93+ biperm_a12_to_dest = invbiperm (biperm_dest_to_a12, axes_dest)
94+ return permuteblockeddims (a12, biperm_a12_to_dest)
9195end
9296
9397function unmatricize (
@@ -111,14 +115,19 @@ function unmatricize(
111115 return unmatricize (m, blocked_axes)
112116end
113117
114- function unmatricize! (a, m:: AbstractMatrix , biperm:: AbstractBlockPermutation{2} )
115- ndims (a) == length (biperm) ||
118+ function unmatricize! (
119+ a_dest, m:: AbstractMatrix , biperm_dest_to_a12:: AbstractBlockPermutation{2}
120+ )
121+ ndims (a_dest) == length (biperm_dest_to_a12) ||
116122 throw (ArgumentError (" destination does not match permutation" ))
117- blocked_axes = axes (a)[biperm ]
123+ blocked_axes = axes (a_dest)[biperm_dest_to_a12 ]
118124 a_perm = unmatricize (m, blocked_axes)
119- return permuteblockeddims! (a, a_perm, invperm (biperm))
125+ biperm_a12_to_dest = invbiperm (biperm_dest_to_a12, axes (a_dest))
126+ return permuteblockeddims! (a_dest, a_perm, biperm_a12_to_dest)
120127end
121128
122- function unmatricize_add! (a_dest, a_dest_mat, biperm_dest, α, β)
123- return mul! (a_dest, 1.0 , unmatricize (a_dest_mat, axes (a_dest), biperm_dest), α, β)
129+ function unmatricize_add! (a_dest, a_dest_mat, biperm_dest_to_a12, α, β)
130+ a12 = unmatricize (a_dest_mat, axes (a_dest), biperm_dest_to_a12)
131+ a_dest .= α .* a12 .+ β .* a_dest
132+ return a_dest
124133end
0 commit comments