|
1 | 1 | module TensorAlgebraTensorOperationsExt |
2 | 2 |
|
| 3 | +using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm |
| 4 | +using TupleTools |
| 5 | +using TensorOperations |
| 6 | +using TensorOperations: AbstractBackend as TOAlgorithm |
| 7 | + |
| 8 | +TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend |
| 9 | + |
| 10 | +trivtuple(n) = ntuple(identity, n) |
| 11 | + |
| 12 | +function _index2tuple(p::BlockedPermutation{2}) |
| 13 | + N₁, N₂ = blocklengths(p) |
| 14 | + return ( |
| 15 | + TupleTools.getindices(Tuple(p), trivtuple(N₁)), |
| 16 | + TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), |
| 17 | + ) |
3 | 18 | end |
| 19 | + |
| 20 | +_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) |
| 21 | + |
| 22 | +# Using TensorAlgebra implementations as TensorOperations backends |
| 23 | +# ---------------------------------------------------------------- |
| 24 | +function TensorOperations.tensorcontract!( |
| 25 | + C::AbstractArray, |
| 26 | + A::AbstractArray, |
| 27 | + pA::Index2Tuple, |
| 28 | + conjA::Bool, |
| 29 | + B::AbstractArray, |
| 30 | + pB::Index2Tuple, |
| 31 | + conjB::Bool, |
| 32 | + pAB::Index2Tuple, |
| 33 | + α::Number, |
| 34 | + β::Number, |
| 35 | + backend::Algorithm, |
| 36 | + allocator, |
| 37 | +) |
| 38 | + bipermA = _blockedpermutation(pA) |
| 39 | + bipermB = _blockedpermutation(pB) |
| 40 | + ipAB = invperm((pAB[1]..., pAB[2]...)) |
| 41 | + bipermAB = _blockedpermutation((TupleTools.getindices(ipAB, trivtuple(length(pA[1]))), |
| 42 | + TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])))) |
| 43 | + A′ = conjA ? conj(A) : A |
| 44 | + B′ = conjB ? conj(B) : B |
| 45 | + return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) |
| 46 | +end |
| 47 | + |
0 commit comments