Skip to content

Commit 99ca8f2

Browse files
committed
use unmatricize_add!
1 parent 6e71131 commit 99ca8f2

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

src/contract/contract_matricize/contract.jl

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
using LinearAlgebra: mul!
22

3+
function isinplace(::AbstractArray, biperm_out)
4+
return istrivialperm(Tuple(biperm_out))
5+
end
6+
37
function contract!(
8+
alg::Matricize,
9+
a_dest::AbstractArray,
10+
biperm_out::AbstractBlockPermutation{2},
11+
a1::AbstractArray,
12+
biperm1::AbstractBlockPermutation{2},
13+
a2::AbstractArray,
14+
biperm2::AbstractBlockPermutation{2},
15+
α::Number,
16+
β::Number,
17+
)
18+
if isinplace(a_dest, biperm_out)
19+
return contract_inplace!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β)
20+
else
21+
return contract_outofplace!(alg, a_dest, biperm_out, a1, biperm1, a2, biperm2, α, β)
22+
end
23+
end
24+
25+
function contract_inplace!(
426
::Matricize,
527
a_dest::AbstractArray,
628
biperm_out::AbstractBlockPermutation{2},
@@ -17,6 +39,26 @@ function contract!(
1739
a1_mat = matricize(a1, biperm1)
1840
a2_mat = matricize(a2, biperm2)
1941
mul!(a_dest_mat, a1_mat, a2_mat, α, β)
20-
unmatricize!(a_dest, a_dest_mat, biperm_dest)
42+
unmatricize!(a_dest, a_dest_mat, biperm_dest) # TODO remove: need no copy in matricize
43+
return a_dest
44+
end
45+
46+
function contract_outofplace!(
47+
::Matricize,
48+
a_dest::AbstractArray,
49+
biperm_out::AbstractBlockPermutation{2},
50+
a1::AbstractArray,
51+
biperm1::AbstractBlockPermutation{2},
52+
a2::AbstractArray,
53+
biperm2::AbstractBlockPermutation{2},
54+
α::Number,
55+
β::Number,
56+
)
57+
biperm_dest = invbiperm(biperm_out, Val(first(blocklengths(biperm1))))
58+
check_input(contract, a_dest, biperm_dest, a1, biperm1, a2, biperm2)
59+
a1_mat = matricize(a1, biperm1)
60+
a2_mat = matricize(a2, biperm2)
61+
a_dest_mat = a1_mat * a2_mat
62+
unmatricize_add!(a_dest, a_dest_mat, biperm_dest, α, β)
2163
return a_dest
2264
end

src/matricize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,7 @@ function unmatricize!(a, m::AbstractMatrix, biperm::AbstractBlockPermutation{2})
115115
a_perm = unmatricize(m, blocked_axes)
116116
return permuteblockeddims!(a, a_perm, invperm(biperm))
117117
end
118+
119+
function unmatricize_add!(a_dest, a_dest_mat, biperm_dest, α, β)
120+
return mul!(a_dest, 1.0, unmatricize(a_dest_mat, axes(a_dest), biperm_dest), α, β)
121+
end

0 commit comments

Comments
 (0)