From 78553899dee059870ec14f081f850315942c05b7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 27 Aug 2025 15:14:14 +0200 Subject: [PATCH 1/5] Add package extension --- Project.toml | 7 +++++++ ext/TensorAlgebraTensorOperationsExt.jl | 3 +++ 2 files changed, 10 insertions(+) create mode 100644 ext/TensorAlgebraTensorOperationsExt.jl diff --git a/Project.toml b/Project.toml index 1c77028..647b015 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,13 @@ EllipsisNotation = "1.8.0" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2" TensorProducts = "0.1.5" +TensorOperations = "5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" + +[weakdeps] +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" + +[extensions] +TensorAlgebraTensorOperationsExt = "TensorOperations" diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl new file mode 100644 index 0000000..2bb6c92 --- /dev/null +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -0,0 +1,3 @@ +module TensorAlgebraTensorOperationsExt + +end From 5cfaaf29fbf3225cc990c8e70d9dd56a724e0637 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 27 Aug 2025 15:15:09 +0200 Subject: [PATCH 2/5] TensorAlgebra for TensorOperations --- ext/TensorAlgebraTensorOperationsExt.jl | 44 +++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl index 2bb6c92..e56d805 100644 --- a/ext/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -1,3 +1,47 @@ module TensorAlgebraTensorOperationsExt +using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm +using TupleTools +using TensorOperations +using TensorOperations: AbstractBackend as TOAlgorithm + +TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend + +trivtuple(n) = ntuple(identity, n) + +function _index2tuple(p::BlockedPermutation{2}) + N₁, N₂ = blocklengths(p) + return ( + TupleTools.getindices(Tuple(p), trivtuple(N₁)), + TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), + ) end + +_blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) + +# Using TensorAlgebra implementations as TensorOperations backends +# ---------------------------------------------------------------- +function TensorOperations.tensorcontract!( + C::AbstractArray, + A::AbstractArray, + pA::Index2Tuple, + conjA::Bool, + B::AbstractArray, + pB::Index2Tuple, + conjB::Bool, + pAB::Index2Tuple, + α::Number, + β::Number, + backend::Algorithm, + allocator, +) + bipermA = _blockedpermutation(pA) + bipermB = _blockedpermutation(pB) + ipAB = invperm((pAB[1]..., pAB[2]...)) + bipermAB = _blockedpermutation((TupleTools.getindices(ipAB, trivtuple(length(pA[1]))), + TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])))) + A′ = conjA ? conj(A) : A + B′ = conjB ? conj(B) : B + return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) +end + From 3ee9224392d38956ab94efd379baadaf4dafb31e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 27 Aug 2025 15:22:24 +0200 Subject: [PATCH 3/5] TensorOperations for TensorAlgebra --- ext/TensorAlgebraTensorOperationsExt.jl | 79 ++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl index e56d805..f677c14 100644 --- a/ext/TensorAlgebraTensorOperationsExt.jl +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -19,6 +19,78 @@ end _blockedpermutation(p::Index2Tuple) = TensorAlgebra.blockedpermvcat(p...) +# Using TensorOperations backends as TensorAlgebra implementations +# ---------------------------------------------------------------- + +# not in-place +function TensorAlgebra.contract( + backend::TOAlgorithm, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, +) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + + # TODO: this assumes biperm of output because not enough information! + ipermAB = invperm(Tuple(bipermAB)) + pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ()) + + return tensorcontract(A, pA, false, B, pB, false, pAB, α, backend) +end + +function TensorAlgebra.contract( + backend::TOAlgorithm, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, +) + return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend) +end + +# in-place +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + bipermAB::BlockedPermutation, + A::AbstractArray, + bipermA::BlockedPermutation, + B::AbstractArray, + bipermB::BlockedPermutation, + α::Number, + β::Number, +) + pA = _index2tuple(bipermA) + pB = _index2tuple(bipermB) + + # TODO: this assumes biperm of output because not enough information! + ipermAB = invperm(Tuple(bipermAB)) + pAB = (TupleTools.getindices(ipermAB, length(ipermAB)), ()) + + return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend) +end + +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, + β::Number, +) + pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) + return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend) +end + # Using TensorAlgebra implementations as TensorOperations backends # ---------------------------------------------------------------- function TensorOperations.tensorcontract!( @@ -38,10 +110,13 @@ function TensorOperations.tensorcontract!( bipermA = _blockedpermutation(pA) bipermB = _blockedpermutation(pB) ipAB = invperm((pAB[1]..., pAB[2]...)) - bipermAB = _blockedpermutation((TupleTools.getindices(ipAB, trivtuple(length(pA[1]))), - TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])))) + bipermAB = _blockedpermutation(( + TupleTools.getindices(ipAB, trivtuple(length(pA[1]))), + TupleTools.getindices(ipAB, trivtuple(length(pB[2])) .+ length(pA[1])), + )) A′ = conjA ? conj(A) : A B′ = conjB ? conj(B) : B return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β) end +end From 3c06f1e3477e7a81e7cede297d1d99e7eb902496 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 27 Aug 2025 15:22:28 +0200 Subject: [PATCH 4/5] Add tests --- test/test_tensoroperations.jl | 41 +++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 test/test_tensoroperations.jl diff --git a/test/test_tensoroperations.jl b/test/test_tensoroperations.jl new file mode 100644 index 0000000..067ef1e --- /dev/null +++ b/test/test_tensoroperations.jl @@ -0,0 +1,41 @@ +using Test: @test, @testset +using TensorOperations: @tensor, ncon +using TensorAlgebra: Matricize + +elts = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "tensor network examples ($T)" for T in elts + D1, D2, D3 = 30, 40, 20 + d1, d2 = 2, 3 + A1 = rand(T, D1, d1, D2) .- 1//2 + A2 = rand(T, D2, d2, D3) .- 1//2 + rhoL = rand(T, D1, D1) .- 1//2 + rhoR = rand(T, D3, D3) .- 1//2 + H = rand(T, d1, d2, d1, d2) .- 1//2 + + @tensor HrA12[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + @tensor backend = Matricize() HrA12′[a, s1, s2, c] := + rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2] + + @test HrA12 ≈ HrA12′ + @test HrA12 ≈ ncon( + [rhoL, H, A2, rhoR, A1], + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; + backend=Matricize(), + ) + E = @tensor rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) + @test E ≈ @tensor backend = Matricize() rhoL[a', a] * + A1[a, s, b] * + A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * + conj(A1[a', t, b']) * + conj(A2[b', t', c']) +end From fa81eca26f77718ecf77f08a5470ed0ff020d924 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 27 Aug 2025 15:47:02 +0200 Subject: [PATCH 5/5] Bump v0.3.12 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 647b015..263a613 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.11" +version = "0.3.12" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"