Skip to content

Commit 5cfaaf2

Browse files
committed
TensorAlgebra for TensorOperations
1 parent 7855389 commit 5cfaaf2

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,47 @@
11
module TensorAlgebraTensorOperationsExt
22

3+
using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm
4+
using TupleTools
5+
using TensorOperations
6+
using TensorOperations: AbstractBackend as TOAlgorithm
7+
8+
TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend
9+
10+
trivtuple(n) = ntuple(identity, n)
11+
12+
function _index2tuple(p::BlockedPermutation{2})
13+
N₁, N₂ = blocklengths(p)
14+
return (
15+
TupleTools.getindices(Tuple(p), trivtuple(N₁)),
16+
TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)),
17+
)
318
end
19+
20+
_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...)
21+
22+
# Using TensorAlgebra implementations as TensorOperations backends
23+
# ----------------------------------------------------------------
24+
function TensorOperations.tensorcontract!(
25+
C::AbstractArray,
26+
A::AbstractArray,
27+
pA::Index2Tuple,
28+
conjA::Bool,
29+
B::AbstractArray,
30+
pB::Index2Tuple,
31+
conjB::Bool,
32+
pAB::Index2Tuple,
33+
α::Number,
34+
β::Number,
35+
backend::Algorithm,
36+
allocator,
37+
)
38+
bipermA = _blockedpermutation(pA)
39+
bipermB = _blockedpermutation(pB)
40+
ipAB = invperm((pAB[1]..., pAB[2]...))
41+
bipermAB = _blockedpermutation((TupleTools.getindices(ipAB, trivtuple(length(pA[1]))),
42+
TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1]))))
43+
A′ = conjA ? conj(A) : A
44+
B′ = conjB ? conj(B) : B
45+
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
46+
end
47+

0 commit comments

Comments
 (0)