Skip to content

Commit 3ee9224

Browse files
committed
TensorOperations for TensorAlgebra
1 parent 5cfaaf2 commit 3ee9224

File tree

1 file changed

+77
-2
lines changed

1 file changed

+77
-2
lines changed

ext/TensorAlgebraTensorOperationsExt.jl

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,78 @@ end
1919

2020
_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...)
2121

22+
# Using TensorOperations backends as TensorAlgebra implementations
23+
# ----------------------------------------------------------------
24+
25+
# not in-place
26+
function TensorAlgebra.contract(
27+
backend::TOAlgorithm,
28+
bipermAB::BlockedPermutation,
29+
A::AbstractArray,
30+
bipermA::BlockedPermutation,
31+
B::AbstractArray,
32+
bipermB::BlockedPermutation,
33+
α::Number,
34+
)
35+
pA = _index2tuple(bipermA)
36+
pB = _index2tuple(bipermB)
37+
38+
# TODO: this assumes biperm of output because not enough information!
39+
ipermAB = invperm(Tuple(bipermAB))
40+
pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ())
41+
42+
return tensorcontract(A, pA, false, B, pB, false, pAB, α, backend)
43+
end
44+
45+
function TensorAlgebra.contract(
46+
backend::TOAlgorithm,
47+
labelsC,
48+
A::AbstractArray,
49+
labelsA,
50+
B::AbstractArray,
51+
labelsB,
52+
α::Number,
53+
)
54+
return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend)
55+
end
56+
57+
# in-place
58+
function TensorAlgebra.contract!(
59+
backend::TOAlgorithm,
60+
C::AbstractArray,
61+
bipermAB::BlockedPermutation,
62+
A::AbstractArray,
63+
bipermA::BlockedPermutation,
64+
B::AbstractArray,
65+
bipermB::BlockedPermutation,
66+
α::Number,
67+
β::Number,
68+
)
69+
pA = _index2tuple(bipermA)
70+
pB = _index2tuple(bipermB)
71+
72+
# TODO: this assumes biperm of output because not enough information!
73+
ipermAB = invperm(Tuple(bipermAB))
74+
pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ())
75+
76+
return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend)
77+
end
78+
79+
function TensorAlgebra.contract!(
80+
backend::TOAlgorithm,
81+
C::AbstractArray,
82+
labelsC,
83+
A::AbstractArray,
84+
labelsA,
85+
B::AbstractArray,
86+
labelsB,
87+
α::Number,
88+
β::Number,
89+
)
90+
pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC)
91+
return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend)
92+
end
93+
2294
# Using TensorAlgebra implementations as TensorOperations backends
2395
# ----------------------------------------------------------------
2496
function TensorOperations.tensorcontract!(
@@ -38,10 +110,13 @@ function TensorOperations.tensorcontract!(
38110
bipermA = _blockedpermutation(pA)
39111
bipermB = _blockedpermutation(pB)
40112
ipAB = invperm((pAB[1]..., pAB[2]...))
41-
bipermAB = _blockedpermutation((TupleTools.getindices(ipAB, trivtuple(length(pA[1]))),
42-
TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1]))))
113+
bipermAB = _blockedpermutation((
114+
TupleTools.getindices(ipAB, trivtuple(length(pA[1]))),
115+
TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])),
116+
))
43117
A′ = conjA ? conj(A) : A
44118
B′ = conjB ? conj(B) : B
45119
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
46120
end
47121

122+
end

0 commit comments

Comments
 (0)