Skip to content

Commit b663fbf

Browse files
committed
TensorOperations for TensorAlgebra
1 parent 7078b4d commit b663fbf

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

ext/TensorAlgebraTensorOperationsExt.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,74 @@ end
1919

2020
_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...)
2121

22+
# Using TensorOperations backends as TensorAlgebra implementations
23+
# ----------------------------------------------------------------
24+
25+
# not in-place
26+
function TensorAlgebra.contract(
27+
backend::TOAlgorithm,
28+
bipermAB::BlockedPermutation,
29+
A::AbstractArray,
30+
bipermA::BlockedPermutation,
31+
B::AbstractArray,
32+
bipermB::BlockedPermutation,
33+
α::Number,
34+
)
35+
pA = _index2tuple(bipermA)
36+
pB = _index2tuple(bipermB)
37+
38+
# TODO: this assumes biperm of output because not enough information!
39+
ipermAB = invperm(Tuple(bipermAB))
40+
pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ())
41+
42+
return tensorcontract(A, pA, false, B, pB, false, pAB, α, backend)
43+
end
44+
45+
function TensorAlgebra.contract(
46+
backend::TOAlgorithm,
47+
labelsC,
48+
A::AbstractArray,
49+
labelsA,
50+
B::AbstractArray,
51+
labelsB,
52+
α::Number,
53+
)
54+
return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend)
55+
end
56+
57+
# in-place
58+
function TensorAlgebra.contract!(
59+
backend::TOAlgorithm,
60+
C::AbstractArray,
61+
bipermAB::BlockedPermutation,
62+
A::AbstractArray,
63+
bipermA::BlockedPermutation,
64+
B::AbstractArray,
65+
bipermB::BlockedPermutation,
66+
α::Number,
67+
β::Number,
68+
)
69+
pA = _index2tuple(bipermA)
70+
pB = _index2tuple(bipermB)
71+
pAB = _index2tuple(bipermAB)
72+
return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend)
73+
end
74+
75+
function TensorAlgebra.contract!(
76+
backend::TOAlgorithm,
77+
C::AbstractArray,
78+
labelsC,
79+
A::AbstractArray,
80+
labelsA,
81+
B::AbstractArray,
82+
labelsB,
83+
α::Number,
84+
β::Number,
85+
)
86+
pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC)
87+
return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend)
88+
end
89+
2290
# Using TensorAlgebra implementations as TensorOperations backends
2391
# ----------------------------------------------------------------
2492
function TensorOperations.tensorcontract!(
@@ -43,3 +111,4 @@ function TensorOperations.tensorcontract!(
43111
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
44112
end
45113

114+
end

0 commit comments

Comments
 (0)