4545# matrix factorizations assume copy
4646# maybe: copy=false kwarg
4747
48- function matricize (a:: AbstractArray , biperm :: AbstractBlockPermutation{2} )
49- ndims (a) == length (biperm ) || throw (ArgumentError (" Invalid bipermutation" ))
50- return matricize (FusionStyle (a), a, biperm )
48+ function matricize (a:: AbstractArray , biperm_dest :: AbstractBlockPermutation{2} )
49+ ndims (a) == length (biperm_dest ) || throw (ArgumentError (" Invalid bipermutation" ))
50+ return matricize (FusionStyle (a), a, biperm_dest )
5151end
5252
5353function matricize (
54- style:: FusionStyle , a:: AbstractArray , biperm :: AbstractBlockPermutation{2}
54+ style:: FusionStyle , a:: AbstractArray , biperm_dest :: AbstractBlockPermutation{2}
5555)
56- a_perm = permuteblockeddims (a, biperm )
57- return matricize (style, a_perm, trivialperm (biperm ))
56+ a_perm = permuteblockeddims (a, biperm_dest )
57+ return matricize (style, a_perm, trivialperm (biperm_dest ))
5858end
5959
6060function matricize (
61- style:: FusionStyle , a:: AbstractArray , biperm :: BlockedTrivialPermutation{2}
61+ style:: FusionStyle , a:: AbstractArray , biperm_dest :: BlockedTrivialPermutation{2}
6262)
63- return throw (MethodError (matricize, Tuple{typeof (style),typeof (a),typeof (biperm )}))
63+ return throw (MethodError (matricize, Tuple{typeof (style),typeof (a),typeof (biperm_dest )}))
6464end
6565
6666# default is reshape
67- function matricize (:: ReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2} )
68- new_axes = fuseaxes (axes (a), biperm)
67+ function matricize (
68+ :: ReshapeFusion , a:: AbstractArray , biperm_dest:: BlockedTrivialPermutation{2}
69+ )
70+ new_axes = fuseaxes (axes (a), biperm_dest)
6971 return reshape (a, new_axes... )
7072end
7173
@@ -74,17 +76,20 @@ function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
7476end
7577
7678# ==================================== unmatricize =======================================
77- function unmatricize (m:: AbstractMatrix , axes, biperm:: AbstractBlockPermutation{2} )
78- length (axes) == length (biperm) || throw (ArgumentError (" axes do not match permutation" ))
79- return unmatricize (FusionStyle (m), m, axes, biperm)
79+ function unmatricize (m:: AbstractMatrix , axes_dest, invbiperm:: AbstractBlockPermutation{2} )
80+ length (axes_dest) == length (invbiperm) ||
81+ throw (ArgumentError (" axes do not match permutation" ))
82+ return unmatricize (FusionStyle (m), m, axes_dest, invbiperm)
8083end
8184
8285function unmatricize (
83- :: FusionStyle , m:: AbstractMatrix , axes, biperm :: AbstractBlockPermutation{2}
86+ :: FusionStyle , m:: AbstractMatrix , axes_dest, invbiperm :: AbstractBlockPermutation{2}
8487)
85- blocked_axes = axes[biperm]
86- a_perm = unmatricize (m, blocked_axes)
87- return permuteblockeddims (a_perm, invperm (biperm))
88+ blocked_axes = axes_dest[invbiperm]
89+ a12 = unmatricize (m, blocked_axes)
90+ biperm_dest = biperm (invperm (invbiperm), length_codomain (axes_dest))
91+
92+ return permuteblockeddims (a12, biperm_dest)
8893end
8994
9095function unmatricize (
@@ -108,10 +113,17 @@ function unmatricize(
108113 return unmatricize (m, blocked_axes)
109114end
110115
111- function unmatricize! (a , m:: AbstractMatrix , biperm :: AbstractBlockPermutation{2} )
112- ndims (a ) == length (biperm ) ||
116+ function unmatricize! (a_dest , m:: AbstractMatrix , invbiperm :: AbstractBlockPermutation{2} )
117+ ndims (a_dest ) == length (invbiperm ) ||
113118 throw (ArgumentError (" destination does not match permutation" ))
114- blocked_axes = axes (a)[biperm ]
119+ blocked_axes = axes (a_dest)[invbiperm ]
115120 a_perm = unmatricize (m, blocked_axes)
116- return permuteblockeddims! (a, a_perm, invperm (biperm))
121+ biperm_dest = biperm (invperm (invbiperm), length_codomain (axes (a_dest)))
122+ return permuteblockeddims! (a_dest, a_perm, biperm_dest)
123+ end
124+
125+ function unmatricize_add! (a_dest, a_dest_mat, invbiperm, α, β)
126+ a12 = unmatricize (a_dest_mat, axes (a_dest), invbiperm)
127+ a_dest .= α .* a12 .+ β .* a_dest
128+ return a_dest
117129end
0 commit comments