Skip to content

Commit 6dc4a28

Browse files
lkdvosJutho
authored andcommitted
Add allocation hook for mul of tensors
1 parent 234a973 commit 6dc4a28

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

src/tensors/linalg.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,27 @@ Base.:\(α::Number, t::AbstractTensorMap) = *(t, one(scalartype(t)) / α)
1818
LinearAlgebra.normalize!(t::AbstractTensorMap, p::Real=2) = scale!(t, inv(norm(t, p)))
1919
LinearAlgebra.normalize(t::AbstractTensorMap, p::Real=2) = scale(t, inv(norm(t, p)))
2020

21+
# destination allocation for matrix multiplication
22+
function compose_dest(A::AbstractTensorMap, B::AbstractTensorMap)
23+
TC = TO.promote_contract(scalartype(A), scalartype(B), One)
24+
pA = (codomainind(A), domainind(A))
25+
pB = (codomainind(B), domainind(B))
26+
pAB = (codomainind(A), ntuple(i -> i + numout(A), numin(B)))
27+
return TO.tensoralloc_contract(TC,
28+
A, pA, false,
29+
B, pB, false,
30+
pAB, Val(false))
31+
end
32+
2133
"""
2234
compose(t1::AbstractTensorMap, t2::AbstractTensorMap) -> AbstractTensorMap
2335
2436
Return the `AbstractTensorMap` that implements the composition of the two tensor maps `t1`
2537
and `t2`.
2638
"""
27-
function compose(t1::AbstractTensorMap, t2::AbstractTensorMap)
28-
return mul!(similar(t1, promote_type(scalartype(t1), scalartype(t2)),
29-
compose(space(t1), space(t2))), t1, t2)
39+
function compose(A::AbstractTensorMap, B::AbstractTensorMap)
40+
C = compose_dest(A, B)
41+
return mul!(C, A, B)
3042
end
3143
Base.:*(t1::AbstractTensorMap, t2::AbstractTensorMap) = compose(t1, t2)
3244

0 commit comments

Comments
 (0)