Skip to content

Commit a1344e8

Browse files
committed
copy kwarg
1 parent 99ca8f2 commit a1344e8

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/contract/contract_matricize/contract.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ function contract_inplace!(
3535
)
3636
biperm_dest = invbiperm(biperm_out, Val(first(blocklengths(biperm1))))
3737
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
38-
a_dest_mat = matricize(a_dest, biperm_dest)
39-
a1_mat = matricize(a1, biperm1)
40-
a2_mat = matricize(a2, biperm2)
38+
a_dest_mat = matricize(a_dest, biperm_dest; copy=false)
39+
a1_mat = matricize(a1, biperm1; copy=false)
40+
a2_mat = matricize(a2, biperm2; copy=false)
4141
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
42-
unmatricize!(a_dest, a_dest_mat, biperm_dest) # TODO remove: need no copy in matricize
4342
return a_dest
4443
end
4544

src/matricize.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,23 @@ end
4545
# matrix factorizations assume copy
4646
# maybe: copy=false kwarg
4747

48-
function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2})
48+
function matricize(a::AbstractArray, biperm::AbstractBlockPermutation{2}; copy=false)
4949
ndims(a) == length(biperm) || throw(ArgumentError("Invalid bipermutation"))
50-
return matricize(FusionStyle(a), a, biperm)
50+
return matricize(FusionStyle(a), a, biperm; copy)
5151
end
5252

5353
function matricize(
54-
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}
54+
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}; copy=false
5555
)
56+
if istrivialperm(Tuple(biperm)) && !copy
57+
return matricize(style, a, trivialperm(biperm))
58+
end
5659
a_perm = permuteblockeddims(a, biperm)
5760
return matricize(style, a_perm, trivialperm(biperm))
5861
end
5962

6063
function matricize(
61-
style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
64+
style::FusionStyle, a::AbstractArray, biperm::BlockedTrivialPermutation{2}; copy=false
6265
)
6366
return throw(MethodError(matricize, Tuple{typeof(style),typeof(a),typeof(biperm)}))
6467
end
@@ -69,8 +72,8 @@ function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPerm
6972
return reshape(a, new_axes...)
7073
end
7174

72-
function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple)
73-
return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a))))
75+
function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple; copy=false)
76+
return matricize(a, blockedpermvcat(permblock1, permblock2; length=Val(ndims(a))); copy)
7477
end
7578

7679
# ==================================== unmatricize =======================================

0 commit comments

Comments
 (0)