diff --git a/Project.toml b/Project.toml index 1c77028..263a613 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.11" +version = "0.3.12" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -20,6 +20,13 @@ EllipsisNotation = "1.8.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2" TensorProducts = "0.1.5" +TensorOperations = "5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" + +[weakdeps] +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" + +[extensions] +TensorAlgebraTensorOperationsExt = "TensorOperations" diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl new file mode 100644 index 0000000..f677c14 --- /dev/null +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -0,0 +1,122 @@ +module TensorAlgebraTensorOperationsExt + +using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm +using TupleTools +using TensorOperations +using TensorOperations: AbstractBackend as TOAlgorithm + +TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend + +trivtuple(n) = ntuple(identity, n) + +function _index2tuple(p::BlockedPermutation{2}) + N₁, N₂ = blocklengths(p) + return ( + TupleTools.getindices(Tuple(p), trivtuple(N₁)), + TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), + ) +end + +_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) + +# Using TensorOperations backends as TensorAlgebra implementations +# ---------------------------------------------------------------- + +# not in-place +function TensorAlgebra.contract( + backend::TOAlgorithm, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, +) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + + # TODO: this assumes biperm of output because not enough information! + ipermAB = invperm(Tuple(bipermAB)) + pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ()) + + return tensorcontract(A, pA, false, B, pB, false, pAB, α, backend) +end + +function TensorAlgebra.contract( + backend::TOAlgorithm, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, +) + return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend) +end + +# in-place +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, + β::Number, +) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + + # TODO: this assumes biperm of output because not enough information! + ipermAB = invperm(Tuple(bipermAB)) + pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ()) + + return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend) +end + +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, + β::Number, +) + pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) + return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend) +end + +# Using TensorAlgebra implementations as TensorOperations backends +# ---------------------------------------------------------------- +function TensorOperations.tensorcontract!( + C::AbstractArray, + A::AbstractArray, + pA::Index2Tuple, + conjA::Bool, + B::AbstractArray, + pB::Index2Tuple, + conjB::Bool, + pAB::Index2Tuple, + α::Number, + β::Number, + backend::Algorithm, + allocator, +) + bipermA = _blockedpermutation(pA) + bipermB = _blockedpermutation(pB) + ipAB = invperm((pAB[1]..., pAB[2]...)) + bipermAB = _blockedpermutation(( + TupleTools.getindices(ipAB, trivtuple(length(pA[1]))), + TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])), + )) + A′ = conjA ? conj(A) : A + B′ = conjB ? conj(B) : B + return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) +end + +end diff --git a/test/test_tensoroperations.jl b/test/test_tensoroperations.jl new file mode 100644 index 0000000..067ef1e --- /dev/null +++ b/test/test_tensoroperations.jl @@ -0,0 +1,41 @@ +using Test: @test, @testset +using TensorOperations: @tensor, ncon +using TensorAlgebra: Matricize + +elts = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "tensor network examples ($T)" for T in elts + D1, D2, D3 = 30, 40, 20 + d1, d2 = 2, 3 + A1 = rand(T, D1, d1, D2) .- 1//2 + A2 = rand(T, D2, d2, D3) .- 1//2 + rhoL = rand(T, D1, D1) .- 1//2 + rhoR = rand(T, D3, D3) .- 1//2 + H = rand(T, d1, d2, d1, d2) .- 1//2 + + @tensor HrA12[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + @tensor backend = Matricize() HrA12′[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + + @test HrA12 ≈ HrA12′ + @test HrA12 ≈ ncon( + [rhoL, H, A2, rhoR, A1], + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; + backend=Matricize(), + ) + E = @tensor rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) + @test E ≈ @tensor backend = Matricize() rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) +end