Skip to content

Commit 7078b4d

Browse files
committed
TensorAlgebra for TensorOperations
1 parent caa4ab3 commit 7078b4d

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,45 @@
11
module TensorAlgebraTensorOperationsExt
22

3+
using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm
4+
using TupleTools
5+
using TensorOperations
6+
using TensorOperations: AbstractBackend as TOAlgorithm
7+
8+
TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend
9+
10+
trivtuple(n) = ntuple(identity, n)
11+
12+
function _index2tuple(p::BlockedPermutation{2})
13+
N₁, N₂ = blocklengths(p)
14+
return (
15+
TupleTools.getindices(Tuple(p), trivtuple(N₁)),
16+
TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)),
17+
)
318
end
19+
20+
_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...)
21+
22+
# Using TensorAlgebra implementations as TensorOperations backends
23+
# ----------------------------------------------------------------
24+
function TensorOperations.tensorcontract!(
25+
C::AbstractArray,
26+
A::AbstractArray,
27+
pA::Index2Tuple,
28+
conjA::Bool,
29+
B::AbstractArray,
30+
pB::Index2Tuple,
31+
conjB::Bool,
32+
pAB::Index2Tuple,
33+
α::Number,
34+
β::Number,
35+
backend::Algorithm,
36+
allocator,
37+
)
38+
bipermA = _blockedpermutation(pA)
39+
bipermB = _blockedpermutation(pB)
40+
bipermAB = _blockedpermutation(pAB)
41+
A′ = conjA ? conj(A) : A
42+
B′ = conjB ? conj(B) : B
43+
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
44+
end
45+

0 commit comments

Comments
 (0)