From d32b1c8271c89f61ac99f4c629b8854bf6a99d1e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 21 Jan 2024 10:21:18 +0100 Subject: [PATCH 01/29] Add rrule planaradd! --- ext/TensorKitChainRulesCoreExt.jl | 42 +++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index b36976fda..deb5faf3e 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -629,6 +629,48 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, return ΔA end +# Planar rrules +# -------------- +function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorMap{S,N₁,N₂}, + A::AbstractTensorMap{S}, p::Index2Tuple{N₁,N₂}, + α::Number, β::Number, + backend::Backend...) where {S,N₁,N₂} + C′ = planaradd!(copy(C), A, p, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function planaradd_pullback(ΔC′) + ΔC = unthunk(ΔC′) + + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ip = TensorKit._canonicalize(invperm(linearize(p)), A) + _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) + _dA = planaradd!(_dA, ip, ΔC, conj(α), Zero(), backend...) + return projectA(_dA) + end + dα = @thunk begin + _dα = tensorscalar(planarcontract(A, ((), linearize(p)), + ΔC, (trivtuple(numind(p)), ()), + ((), ()), One(), Zero(), backend...)) + return projectα(_dα) + end + dβ = @thunk begin + _dβ = tensorscalar(planarcontract(C, ((), trivtuple(numind(pC))), + ΔC, (trivtuple(numind(pC)), ()), + ((), ()), One(), Zero(), backend...)) + return projectβ(_dβ) + end + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, dA, NoTangent(), dα, dβ, dbackend... + end + + return C′, planaradd_pullback +end + # Convert rrules #---------------- function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) From 276f3bb4fcce4c87c5450e3c58dc14f92c2c9c10 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 21 Jan 2024 10:46:28 +0100 Subject: [PATCH 02/29] Add rrule `planarcontract` --- ext/TensorKitChainRulesCoreExt.jl | 53 +++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index deb5faf3e..63d73a77a 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -671,6 +671,59 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorM return C′, planaradd_pullback end +function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), + C::AbstractTensorMap{S,N₁,N₂}, + A::AbstractTensorMap{S}, pA::Index2Tuple, + B::AbstractTensorMap{S}, pB::Index2Tuple, + pAB::Index2Tuple{N₁,N₂}, + α::Number, β::Number, backend::Backend...) where {S,N₁,N₂} + C′ = planarcontract!(copy(C), A, pA, B, pB, pAB, α, β, backend...) + + projectA = ProjectTo(A) + projectB = ProjectTo(B) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function planarcontract_pullback(ΔC′) + ΔC = unthunk(ΔC′) + pΔC = TensorKit._canonicalize(invperm(linearize(pAB)), ΔC) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipA = TensorKit._canonicalize(invperm(linearize(pA)), (), A) + _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) + _dA = planarcontract!(_dA, ΔC, pΔC, adjoint(B), reverse(pB), ipA, + conj(α), Zero(), backend...) + return projectA(_dA) + end + dB = @thunk begin + ipB = TensorKit._canonicalize((invperm(linearize(pB)), ()), B) + _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) + _dB = planarcontract!(_dB, A', reverse(pA), ΔC, pΔC, ipB, + conj(α), Zero(), backend...) + return projectB(_dB) + end + dα = @thunk begin + AB = planarcontract!(similar(C), A, pA, B, pB, pAB, One(), Zero(), backend...) + _dα = tensorscalar(planarcontract(AB', ((), trivtuple(numind(pAB))), + ΔC, (trivtuple(numind(pAB)), ()), ((), ()), + One(), Zero(), backend...)) + return projectα(_dα) + end + dβ = @thunk begin + _dβ = tensorscalar(planarcontract(C', ((), trivtuple(numind(pAB))), + ΔC, (trivtuple(numind(pAB)), ()), ((), ()), + One(), Zero(), backend...)) + return projectβ(_dβ) + end + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, dA, NoTangent(), dB, NoTangent(), NoTangent(), + dα, dβ, dbackend... + end + + return C′, planarcontract_pullback +end + # Convert rrules #---------------- function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) From 93db86187145a0bbae4d4e0678895b077cbb291f Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 21 Jan 2024 10:52:20 +0100 Subject: [PATCH 03/29] Add `planarcontract` (without `!`) --- src/planar/planaroperations.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index ee23bcd74..35ef92181 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -84,6 +84,15 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, return C end +function planarcontract(A::AbstractTensorMap{S}, pA::Index2Tuple, + B::AbstractTensorMap{S}, pB::Index2Tuple, + pAB::Index2Tuple{N₁,N₂}, + α::Number, backend::Backend...) where {S,N₁,N₂} + TC = promote_contract(scalartype(A), scalartype(B), scalartype(α)) + C = tensoralloc_contract(TC, pC, A, pA, :N, B, pB, :N) + return planarcontract!(C, A, pA, B, pB, pAB, α, β, backend...) +end + # auxiliary routines _cyclicpermute(t::Tuple) = (Base.tail(t)..., t[1]) _cyclicpermute(t::Tuple{}) = () From d84ffb41e2d294d5c8931ff582fdba9bf47e8dbc Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 21 Jan 2024 10:54:04 +0100 Subject: [PATCH 04/29] Fix TensorKitChainRulesCoreExt imports --- ext/TensorKitChainRulesCoreExt.jl | 11 ++++++----- src/planar/planaroperations.jl | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 63d73a77a..a3f02dff0 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -2,6 +2,7 @@ module TensorKitChainRulesCoreExt using TensorOperations using TensorKit +using TensorKit: planaradd!, planarcontract!, planarcontract, _canonicalize using ChainRulesCore using LinearAlgebra using TupleTools @@ -134,7 +135,7 @@ end function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool=false) function permute_pullback(Δtdst) - invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + invp = _canonicalize(TupleTools.invperm(linearize(p)), tsrc) return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() end return permute(tsrc, p; copy=true), permute_pullback @@ -647,7 +648,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorM dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk begin - ip = TensorKit._canonicalize(invperm(linearize(p)), A) + ip = _canonicalize(invperm(linearize(p)), A) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) _dA = planaradd!(_dA, ip, ΔC, conj(α), Zero(), backend...) return projectA(_dA) @@ -687,17 +688,17 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), function planarcontract_pullback(ΔC′) ΔC = unthunk(ΔC′) - pΔC = TensorKit._canonicalize(invperm(linearize(pAB)), ΔC) + pΔC = _canonicalize(invperm(linearize(pAB)), ΔC) dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk begin - ipA = TensorKit._canonicalize(invperm(linearize(pA)), (), A) + ipA = _canonicalize(invperm(linearize(pA)), (), A) _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) _dA = planarcontract!(_dA, ΔC, pΔC, adjoint(B), reverse(pB), ipA, conj(α), Zero(), backend...) return projectA(_dA) end dB = @thunk begin - ipB = TensorKit._canonicalize((invperm(linearize(pB)), ()), B) + ipB = _canonicalize((invperm(linearize(pB)), ()), B) _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) _dB = planarcontract!(_dB, A', reverse(pA), ΔC, pΔC, ipB, conj(α), Zero(), backend...) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 35ef92181..b57be6729 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -88,9 +88,9 @@ function planarcontract(A::AbstractTensorMap{S}, pA::Index2Tuple, B::AbstractTensorMap{S}, pB::Index2Tuple, pAB::Index2Tuple{N₁,N₂}, α::Number, backend::Backend...) where {S,N₁,N₂} - TC = promote_contract(scalartype(A), scalartype(B), scalartype(α)) - C = tensoralloc_contract(TC, pC, A, pA, :N, B, pB, :N) - return planarcontract!(C, A, pA, B, pB, pAB, α, β, backend...) + TC = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + C = TensorOperations.tensoralloc_contract(TC, pAB, A, pA, :N, B, pB, :N) + return planarcontract!(C, A, pA, B, pB, pAB, α, VectorInterface.Zero(), backend...) end # auxiliary routines From 974173d01ca02b30f5c1f0ffddf7d145398aa5ec Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 29 Jan 2024 19:00:00 +0100 Subject: [PATCH 05/29] painful `reorder_indices` rewrite --- src/planar/planaroperations.jl | 181 ++++++++++++++++++++------------- test/planar.jl | 179 ++++++++++++++++++++++++++++++++ 2 files changed, 291 insertions(+), 69 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index b57be6729..6b28a9980 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -44,42 +44,52 @@ function planartrace!(C::AbstractTensorMap{S,N₁,N₂}, return C end -function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, +function planarcontract!(C::AbstractTensorMap{S}, A::AbstractTensorMap{S}, - pA::Index2Tuple, + pA::Index2Tuple{N₁,N₃}, B::AbstractTensorMap{S}, - pB::Index2Tuple, - pAB::Index2Tuple{N₁,N₂}, + pB::Index2Tuple{N₃,N₂}, + pAB::Index2Tuple, α::Number, β::Number, - backend::Backend...) where {S,N₁,N₂} + backend::Backend...) where {S,N₁,N₂,N₃} if BraidingStyle(sectortype(S)) == Bosonic() return contract!(C, A, pA, B, pB, pAB, α, β, backend...) end - codA, domA = codomainind(A), domainind(A) - codB, domB = codomainind(B), domainind(B) - oindA, cindA = pA - cindB, oindB = pB - oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, - oindB, cindB, pAB...) + indA = (codomainind(A), reverse(domainind(A))) + indB = (codomainind(B), reverse(domainind(B))) + pA′, pB′, pAB′ = reorder_indices(indA, pA, indB, pB, pAB) - if oindA == codA && cindA == domA + + if pA′ == (codomainind(A), domainind(A)) A′ = A else - A′ = TO.tensoralloc_add(scalartype(A), (oindA, cindA), A, :N, true) - add_transpose!(A′, A, (oindA, cindA), true, false, backend...) + A′ = TO.tensoralloc_add(scalartype(A), pA′, A, :N, true) + add_transpose!(A′, A, pA′, true, false, backend...) end - if cindB == codB && oindB == domB + if pB′ == (codomainind(B), domainind(B)) B′ = B else - B′ = TensorOperations.tensoralloc_add(scalartype(B), (cindB, oindB), B, :N, true) - add_transpose!(B′, B, (cindB, oindB), true, false, backend...) + B′ = TO.tensoralloc_add(scalartype(B), pB′, B, :N, true) + add_transpose!(B′, B, pB′, true, false, backend...) + end + + ipAB = TupleTools.invperm(linearize(pAB′)) + oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) + oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) + + if has_shared_permute(C, (oindAinC, oindBinC)) + C′ = transpose(C, (oindAinC, oindBinC)) + mul!(C′, A′, B′, α, β) + else + C′ = A′ * B′ + add_transpose!(C, C′, pAB′, α, β) end - mul!(C, A′, B′, α, β) - (oindA == codA && cindA == domA) || TO.tensorfree!(A′) - (cindB == codB && oindB == domB) || TO.tensorfree!(B′) + + pA′ == (codomainind(A), domainind(A)) || TO.tensorfree!(A′) + pB′ == (codomainind(B), domainind(B)) || TO.tensorfree!(B′) return C end @@ -88,8 +98,8 @@ function planarcontract(A::AbstractTensorMap{S}, pA::Index2Tuple, B::AbstractTensorMap{S}, pB::Index2Tuple, pAB::Index2Tuple{N₁,N₂}, α::Number, backend::Backend...) where {S,N₁,N₂} - TC = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α)) - C = TensorOperations.tensoralloc_contract(TC, pAB, A, pA, :N, B, pB, :N) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + C = TO.tensoralloc_contract(TC, pAB, A, pA, :N, B, pB, :N) return planarcontract!(C, A, pA, B, pB, pAB, α, VectorInterface.Zero(), backend...) end @@ -97,56 +107,89 @@ end _cyclicpermute(t::Tuple) = (Base.tail(t)..., t[1]) _cyclicpermute(t::Tuple{}) = () -function reorder_indices(codA, domA, codB, domB, oindA, oindB, p1, p2) - N₁ = length(oindA) - N₂ = length(oindB) - @assert length(p1) == N₁ && all(in(p1), 1:N₁) - @assert length(p2) == N₂ && all(in(p2), N₁ .+ (1:N₂)) - oindA2 = TupleTools.getindices(oindA, p1) - oindB2 = TupleTools.getindices(oindB, p2 .- N₁) - indA = (codA..., reverse(domA)...) - indB = (codB..., reverse(domB)...) - # cycle indA to be of the form (oindA2..., reverse(cindA2)...) - while length(oindA2) > 0 && indA[1] != oindA2[1] - indA = _cyclicpermute(indA) - end - # cycle indB to be of the form (cindB2..., reverse(oindB2)...) - while length(oindB2) > 0 && indB[end] != oindB2[1] - indB = _cyclicpermute(indB) - end - for i in 2:N₁ - @assert indA[i] == oindA2[i] - end - for j in 2:N₂ - @assert indB[end + 1 - j] == oindB2[j] + +function _iscyclicpermutation(v1, v2) + length(v1) == length(v2) || return false + return iscyclicpermutation(_indexin(v1, v2)) +end + +function _findsetcircshift(p_cyclic, p_subset) + N = length(p_cyclic) + M = length(p_subset) + i = findfirst(0:(N - 1)) do i + return issetequal(TupleTools.getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), + p_subset) end - Nc = length(indA) - N₁ - @assert Nc == length(indB) - N₂ - pc = ntuple(identity, Nc) - cindA2 = reverse(TupleTools.getindices(indA, N₁ .+ pc)) - cindB2 = TupleTools.getindices(indB, pc) - return oindA2, cindA2, oindB2, cindB2 + isnothing(i) && throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) + return i-1::Int end -function reorder_indices(codA, domA, codB, domB, oindA, cindA, oindB, cindB, p1, p2) - oindA2, cindA2, oindB2, cindB2 = reorder_indices(codA, domA, codB, domB, oindA, oindB, - p1, p2) +function reorder_planar_indices(indA, pA, indB, pB, pAB) + NA₁ = length(pA[1]) + NA₂ = length(pA[2]) + NA = NA₁ + NA₂ + NB₁ = length(pB[1]) + NB₂ = length(pB[2]) + NB = NB₁ + NB₂ + NAB₁ = length(pAB[1]) + NAB₂ = length(pAB[2]) + NAB = NAB₁ + NAB₂ + + # input checks + @assert NA == length(indA[1]) + length(indA[2]) + @assert NB == length(indB[1]) + length(indB[2]) + @assert NA₂ == NB₁ + @assert NAB == NA₁ + NB₂ - #if oindA or oindB are empty, then reorder indices can only order it correctly up to a cyclic permutation! - if isempty(oindA2) && !isempty(cindA) - # isempty(cindA) is a cornercase which I'm not sure if we can encounter - hit = cindA[findfirst(==(first(cindB2)), cindB)] - while hit != first(cindA2) - cindA2 = _cyclicpermute(cindA2) - end - end - if isempty(oindB2) && !isempty(cindB) - hit = cindB[findfirst(==(first(cindA2)), cindA)] - while hit != first(cindB2) - cindB2 = _cyclicpermute(cindB2) - end + # find circshift index of pAB if considered as shifting sets + indAB = (ntuple(identity, NAB₁), + reverse(ntuple(n -> n + NAB₁, NAB₂))) + indAB_lin = (indAB[1]..., indAB[2]...) + iAB = _findsetcircshift(indAB_lin, pAB[1]) + @assert iAB == _findsetcircshift(indAB_lin, pAB[2]) - NAB₁ "sanity check" + + # migrate permutations from pAB to pA and pB + permA = TupleTools.getindices((pAB[1]..., reverse(pAB[2])...), + ntuple(n -> mod1(n + iAB, NAB), NA₁)) + permB = reverse(TupleTools.getindices((pAB[1]..., reverse(pAB[2])...), + ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- + NA₁) + + pA′ = (TupleTools.getindices(pA[1], permA), pA[2]) + pB′ = (pB[1], reverse(TupleTools.getindices(reverse(pB[2]), permB))) + pAB′ = (ntuple(n -> n + iAB, NAB₁), + ntuple(n -> n + iAB + NAB₁, NAB₂)) + + # fix permutations of contracted indices + if NA₂ > 0 + indA_lin = (indA[1]..., indA[2]...) + iA = _findsetcircshift(indA_lin, pA′[1]) + @assert iA == _findsetcircshift(indA_lin, pA′[2]) - NA₁ "sanity check" + pA′ = (pA′[1], + reverse(TupleTools.getindices(linearize(indA), + ntuple(n -> mod1(n + iA + NA₁, NA), NA₂)))) + + indB_lin = (indB[1]..., indB[2]...) + iB = _findsetcircshift(indB_lin, pB′[1]) + @assert iB == _findsetcircshift(indB_lin, pB′[2]) - NB₁ "sanity check" + pB′ = (TupleTools.getindices(linearize(indB), + ntuple(n -> mod1(n + iB, NB), NB₁)), + pB′[2]) end - @assert TupleTools.sort(cindA) == TupleTools.sort(cindA2) - @assert TupleTools.sort(tuple.(cindA2, cindB2)) == TupleTools.sort(tuple.(cindA, cindB)) - return oindA2, cindA2, oindB2, cindB2 + + # make sure this is still the same contraction + @assert issetequal(pA[1], pA′[1]) && issetequal(pA[2], pA′[2]) + @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) + @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) + @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) + + # make sure that everything is now planar + @assert _iscyclicpermutation((indA[1]..., (indA[2])...), + (pA′[1]..., reverse(pA′[2])...)) + @assert _iscyclicpermutation((indB[1]..., (indB[2])...), + (pB′[1]..., reverse(pB′[2])...)) + @assert _iscyclicpermutation((indAB[1]..., (indAB[2])...), + (pAB′[1]..., reverse(pAB′[2])...)) + + return pA′, pB′, pAB′ end diff --git a/test/planar.jl b/test/planar.jl index d7edf1460..c1297a68a 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -27,6 +27,185 @@ function force_planar(tsrc::TensorMap{<:GradedSpace}) return tdst end +using TensorKit: reorder_planar_indices +@testset "reorder_indices" begin + @testset "trivial case" begin + pA = ((1, 2, 3), (4, 5)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA == pA′ + @test pB == pB′ + @test pAB == pAB′ + end + + @testset "trivial case" begin + pA = ((1, 2, 3), (4, 5, 6)) + pB = ((1, 2, 3), (4, 5, 6)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 3) + NB = (3, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA == pA′ + @test pB == pB′ + @test pAB == pAB′ + end + + @testset "swap outer indices" begin + pA = ((2, 3, 1), (4, 5)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((3, 1, 2), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == ((1, 2, 3), (4, 5)) + @test pB′ == pB + @test pAB′ == ((1, 2, 3), (4, 5, 6)) + end + + @testset "swap contracted inds" begin + pA = ((1, 2, 3), (5, 4)) + pB = ((2, 1), (3, 4, 5)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA[1] == pA′[1] + @test pA[2] == reverse(pA′[2]) + @test pB[1] == reverse(pB′[1]) + @test pB[2] == pB′[2] + @test pAB == pAB′ + end + + @testset "trivial case" begin + pA = ((1, 2, 3), (4, 5, 6)) + pB = ((1, 2, 3), (6, 5, 4)) + pAB = ((1, 2, 3), (6, 5, 4)) + NA = (3, 3) + NB = (3, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == ((1, 2, 3), (4, 5, 6)) + @test pB′ == ((1, 2, 3), (4, 5, 6)) + @test pAB′ == ((1, 2, 3), (4, 5, 6)) + end + + @testset "swap uncontracted inds" begin + pA = ((2, 1, 3), (4, 5)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((2, 1, 3), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test (TupleTools.getindices(pA[1], (2, 1, 3)), pA[2]) == pA′ + @test pB == pB′ + @test (TupleTools.getindices(pAB[1], (2, 1, 3)), pAB[2]) == pAB′ + end + + @testset "non-planar contraction" begin + pA = ((2, 1, 3), (4, 5)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + @test_throws AssertionError reorder_planar_indices(indA, pA, indB, pB, pAB) + end + + @testset "non-planar contraction" begin + pA = ((1, 2, 3), (5, 4)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 2) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + @test_throws AssertionError reorder_planar_indices(indA, pA, indB, pB, pAB) + end + + @testset "change input tensor partitions" begin + pA = ((1, 2, 3), (5, 4)) + pB = ((1, 2), (3, 4, 5)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (5, 0) + NB = (2, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == ((1, 2, 3), (5, 4)) + @test pB′ == ((1, 2), (3, 4, 5)) + @test pAB′ == ((1, 2, 3), (4, 5, 6)) + end + + @testset "edge case with no contracted indices" begin + pA = ((1, 2, 3), ()) + pB = ((), (1, 2, 3)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 0) + NB = (0, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == pA + @test pB == pB′ + @test pAB == pAB′ + end + + @testset "edge case with no contracted indices" begin + pA = ((2, 3, 1), ()) + pB = ((), (1, 2, 3)) + pAB = ((1, 2, 3), (4, 5, 6)) + NA = (3, 0) + NB = (0, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == pA + @test pB == pB′ + @test pAB == pAB′ + end + + @testset "edge case with no contracted indices" begin + pA = ((1, 2, 3), ()) + pB = ((), (1, 2, 3)) + pAB = ((2, 3, 1), (4, 5, 6)) + NA = (3, 0) + NB = (0, 3) + indA = (ntuple(identity, NA[1]), reverse(ntuple(identity, NA[2])) .+ NA[1]) + indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == ((2, 3, 1), ()) + @test pB == pB′ + @test pAB′ == ((1, 2, 3), (4, 5, 6)) + end +end + @testset "planar methods" verbose = true begin @testset "planaradd" begin A = TensorMap(randn, ℂ^2 ⊗ ℂ^3 ← ℂ^6 ⊗ ℂ^5 ⊗ ℂ^4) From 31bd0b30310d2ad96f9266db82ee08af823ef44a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 30 Jan 2024 09:35:25 +0100 Subject: [PATCH 06/29] Fix some things and make planar tests run --- src/TensorKit.jl | 2 +- src/planar/planaroperations.jl | 91 +++++++++++++++++++++------------- src/tensors/braidingtensor.jl | 35 +++++++------ test/planar.jl | 4 +- 4 files changed, 80 insertions(+), 52 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index a81aabe32..5aaa68924 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -87,7 +87,7 @@ export notrunc, truncerr, truncdim, truncspace, truncbelow # Imports #--------- using TupleTools -using TupleTools: StaticLength +using TupleTools: StaticLength, getindices using Strided diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 6b28a9980..8676bb79a 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -59,7 +59,7 @@ function planarcontract!(C::AbstractTensorMap{S}, indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) - pA′, pB′, pAB′ = reorder_indices(indA, pA, indB, pB, pAB) + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) if pA′ == (codomainind(A), domainind(A)) @@ -77,8 +77,8 @@ function planarcontract!(C::AbstractTensorMap{S}, end ipAB = TupleTools.invperm(linearize(pAB′)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) + oindAinC = getindices(ipAB, ntuple(n -> n, N₁)) + oindBinC = getindices(ipAB, ntuple(n -> n + N₁, N₂)) if has_shared_permute(C, (oindAinC, oindBinC)) C′ = transpose(C, (oindAinC, oindBinC)) @@ -107,6 +107,10 @@ end _cyclicpermute(t::Tuple) = (Base.tail(t)..., t[1]) _cyclicpermute(t::Tuple{}) = () +_circshift(::Tuple{}, ::Int) = () +_circshift(t::Tuple, n::Int) = ntuple(i -> t[mod1(i - n, length(t))], length(t)) + +_indexin(v1, v2) = ntuple(n -> findfirst(isequal(v1[n]), v2), length(v1)) function _iscyclicpermutation(v1, v2) length(v1) == length(v2) || return false @@ -116,8 +120,9 @@ end function _findsetcircshift(p_cyclic, p_subset) N = length(p_cyclic) M = length(p_subset) + N == M == 0 && return 0 i = findfirst(0:(N - 1)) do i - return issetequal(TupleTools.getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), + return issetequal(getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), p_subset) end isnothing(i) && throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) @@ -142,46 +147,62 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) @assert NAB == NA₁ + NB₂ # find circshift index of pAB if considered as shifting sets - indAB = (ntuple(identity, NAB₁), - reverse(ntuple(n -> n + NAB₁, NAB₂))) - indAB_lin = (indAB[1]..., indAB[2]...) - iAB = _findsetcircshift(indAB_lin, pAB[1]) - @assert iAB == _findsetcircshift(indAB_lin, pAB[2]) - NAB₁ "sanity check" - - # migrate permutations from pAB to pA and pB - permA = TupleTools.getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB, NAB), NA₁)) - permB = reverse(TupleTools.getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- - NA₁) - - pA′ = (TupleTools.getindices(pA[1], permA), pA[2]) - pB′ = (pB[1], reverse(TupleTools.getindices(reverse(pB[2]), permB))) - pAB′ = (ntuple(n -> n + iAB, NAB₁), - ntuple(n -> n + iAB + NAB₁, NAB₂)) - - # fix permutations of contracted indices - if NA₂ > 0 + indAB = (ntuple(identity, NAB₁), reverse(ntuple(n -> n + NAB₁, NAB₂))) + + if NAB > 0 + indAB_lin = (indAB[1]..., indAB[2]...) + iAB = _findsetcircshift(indAB_lin, pAB[1]) + @assert iAB == mod(_findsetcircshift(indAB_lin, pAB[2]) - NAB₁, NAB) "sanity check" + + # migrate permutations from pAB to pA and pB + permA = getindices((pAB[1]..., reverse(pAB[2])...), + ntuple(n -> mod1(n + iAB, NAB), NA₁)) + permB = reverse(getindices((pAB[1]..., reverse(pAB[2])...), + ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- NA₁) + + pA′ = (getindices(pA[1], permA), pA[2]) + pB′ = (pB[1], getindices(pB[2], permB)) + pAB′ = (ntuple(n -> n + iAB, NAB₁), ntuple(n -> n + iAB + NAB₁, NAB₂)) + else + pA′ = pA + pB′ = pB + pAB′ = pAB + end + + # cycle indA to be of the form (oindA..., reverse(cindA)...) + if NA₁ != 0 indA_lin = (indA[1]..., indA[2]...) - iA = _findsetcircshift(indA_lin, pA′[1]) - @assert iA == _findsetcircshift(indA_lin, pA′[2]) - NA₁ "sanity check" + iA = findfirst(==(first(pA′[1])), indA_lin) + @assert all(indA_lin[mod1.(iA .+ (1:NA₁) .- 1, NA)] .== pA′[1]) "sanity check" pA′ = (pA′[1], - reverse(TupleTools.getindices(linearize(indA), - ntuple(n -> mod1(n + iA + NA₁, NA), NA₂)))) - + reverse(getindices(indA_lin, ntuple(n -> mod1(n + iA + NA₁ - 1, NA), NA₂)))) + end + # cycle indB to be of the form (cindB..., reverse(oindB)...) + if NB₂ != 0 indB_lin = (indB[1]..., indB[2]...) - iB = _findsetcircshift(indB_lin, pB′[1]) - @assert iB == _findsetcircshift(indB_lin, pB′[2]) - NB₁ "sanity check" - pB′ = (TupleTools.getindices(linearize(indB), - ntuple(n -> mod1(n + iB, NB), NB₁)), - pB′[2]) + iB = findfirst(==(first(pB′[2])), indB_lin) + @assert all(indB_lin[mod1.(iB .- (1:NB₂) .+ 1, NB)] .== (pB′[2])) "$pB $pB′ $indB_lin $iB" + pB′ = (getindices(indB_lin, ntuple(n -> mod1(n + iB, NB), NB₁)), pB′[2]) + end + + # if uncontracted indices are empty, we can still make cyclic adjustments + if NA₁ == 0 + shiftA = findfirst(==(first(pB′[1])), pB[1]) + @assert !isnothing(shiftA) "pB = $pB, pB′ = $pB′" + pA′ = (pA′[1], _circshift(pA′[2], shiftA-1)) + end + + if NB₂ == 0 + shiftB = findfirst(==(first(pA′[2])), pA[2]) + @assert !isnothing(shiftB) "pA = $pA, pA′ = $pA′" + pB′ = (_circshift(pB′[1], shiftB-1), pB′[2]) end # make sure this is still the same contraction @assert issetequal(pA[1], pA′[1]) && issetequal(pA[2], pA′[2]) @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) - @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) + @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) "$pA $pB $pA′ $pB′" # make sure that everything is now planar @assert _iscyclicpermutation((indA[1]..., (indA[2])...), diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 4b8489d6c..c41b26bcb 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -227,20 +227,22 @@ end function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::BraidingTensor{S}, - (oindA, cindA)::Index2Tuple{2,2}, + pA::Index2Tuple{2,2}, B::AbstractTensorMap{S}, - (cindB, oindB)::Index2Tuple{2,N₃}, - (p1, p2)::Index2Tuple{N₁,N₂}, + pB::Index2Tuple{2,N₃}, + pAB::Index2Tuple{N₁,N₂}, α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} - codA, domA = codomainind(A), domainind(A) - codB, domB = codomainind(B), domainind(B) - oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, - oindB, cindB, p1, p2) + + indA = (codomainind(A), reverse(domainind(A))) + indB = (codomainind(B), reverse(domainind(B))) + pA, pB, pAB = reorder_planar_indices(indA, pA, indB, pB, pAB) + oindA, cindA = pA + cindB, oindB = pB if space(B, cindB[1]) != space(A, cindA[1])' || space(B, cindB[2]) != space(A, cindA[2])' - throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($p1, $p2)")) + throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($pAB)")) end if BraidingStyle(sectortype(B)) isa Bosonic @@ -273,20 +275,25 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, end function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, - (oindA, cindA)::Index2Tuple{N₃,2}, + pA::Index2Tuple{N₃,2}, B::BraidingTensor{S}, - (cindB, oindB)::Index2Tuple{2,2}, - (p1, p2)::Index2Tuple{N₁,N₂}, + pB::Index2Tuple{2,2}, + pAB::Index2Tuple{N₁,N₂}, α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} codA, domA = codomainind(A), domainind(A) codB, domB = codomainind(B), domainind(B) - oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA, - oindB, cindB, p1, p2) + + indA = (codomainind(A), reverse(domainind(A))) + indB = (codomainind(B), reverse(domainind(B))) + pA, pB, pAB = reorder_planar_indices(indA, pA, indB, pB, pAB) + oindA, cindA = pA + cindB, oindB = pB + if space(B, cindB[1]) != space(A, cindA[1])' || space(B, cindB[2]) != space(A, cindA[2])' - throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($p1, $p2)")) + throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($pAB)")) end if BraidingStyle(sectortype(A)) isa Bosonic diff --git a/test/planar.jl b/test/planar.jl index c1297a68a..ce8278e29 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -116,9 +116,9 @@ using TensorKit: reorder_planar_indices indB = (ntuple(identity, NB[1]), reverse(ntuple(identity, NB[2])) .+ NB[1]) pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) - @test (TupleTools.getindices(pA[1], (2, 1, 3)), pA[2]) == pA′ + @test pA′ == ((1, 2, 3), (4, 5)) @test pB == pB′ - @test (TupleTools.getindices(pAB[1], (2, 1, 3)), pAB[2]) == pAB′ + @test pAB′ == ((1, 2, 3), (4, 5, 6)) end @testset "non-planar contraction" begin From 13e4cdfe5aca4c15ec156deedb85735e1150225c Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 30 Jan 2024 11:04:22 +0100 Subject: [PATCH 07/29] more fixes and rewrites --- src/planar/planaroperations.jl | 63 ++++++++++++++++++---------------- test/planar.jl | 13 +++++++ 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 8676bb79a..ec4dfd962 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -61,7 +61,6 @@ function planarcontract!(C::AbstractTensorMap{S}, indB = (codomainind(B), reverse(domainind(B))) pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) - if pA′ == (codomainind(A), domainind(A)) A′ = A else @@ -125,8 +124,9 @@ function _findsetcircshift(p_cyclic, p_subset) return issetequal(getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), p_subset) end - isnothing(i) && throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) - return i-1::Int + isnothing(i) && + throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) + return i - 1::Int end function reorder_planar_indices(indA, pA, indB, pB, pAB) @@ -139,7 +139,7 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) NAB₁ = length(pAB[1]) NAB₂ = length(pAB[2]) NAB = NAB₁ + NAB₂ - + # input checks @assert NA == length(indA[1]) + length(indA[2]) @assert NB == length(indB[1]) + length(indB[2]) @@ -148,7 +148,7 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) # find circshift index of pAB if considered as shifting sets indAB = (ntuple(identity, NAB₁), reverse(ntuple(n -> n + NAB₁, NAB₂))) - + if NAB > 0 indAB_lin = (indAB[1]..., indAB[2]...) iAB = _findsetcircshift(indAB_lin, pAB[1]) @@ -156,9 +156,9 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) # migrate permutations from pAB to pA and pB permA = getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB, NAB), NA₁)) + ntuple(n -> mod1(n + iAB, NAB), NA₁)) permB = reverse(getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- NA₁) + ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- NA₁) pA′ = (getindices(pA[1], permA), pA[2]) pB′ = (pB[1], getindices(pB[2], permB)) @@ -170,32 +170,35 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) end # cycle indA to be of the form (oindA..., reverse(cindA)...) + indA_lin = (indA[1]..., indA[2]...) if NA₁ != 0 - indA_lin = (indA[1]..., indA[2]...) - iA = findfirst(==(first(pA′[1])), indA_lin) - @assert all(indA_lin[mod1.(iA .+ (1:NA₁) .- 1, NA)] .== pA′[1]) "sanity check" - pA′ = (pA′[1], - reverse(getindices(indA_lin, ntuple(n -> mod1(n + iA + NA₁ - 1, NA), NA₂)))) + iA = findfirst(==(first(pA′[1])), indA_lin) - 1 + indA_lin = _circshift(indA_lin, -iA) end + pc = ntuple(identity, NA₂) + @assert all(getindices(indA_lin, ntuple(identity, NA₁)) .== pA′[1]) "sanity check" + pA′ = (pA′[1], reverse(getindices(indA_lin, pc .+ NA₁))) + # cycle indB to be of the form (cindB..., reverse(oindB)...) + indB_lin = (indB[1]..., indB[2]...) if NB₂ != 0 - indB_lin = (indB[1]..., indB[2]...) iB = findfirst(==(first(pB′[2])), indB_lin) - @assert all(indB_lin[mod1.(iB .- (1:NB₂) .+ 1, NB)] .== (pB′[2])) "$pB $pB′ $indB_lin $iB" - pB′ = (getindices(indB_lin, ntuple(n -> mod1(n + iB, NB), NB₁)), pB′[2]) + indB_lin = _circshift(indB_lin, -iB) end - + @assert all(getindices(indB_lin, ntuple(identity, NB₂) .+ NB₁) .== reverse(pB′[2])) "sanity check" + pB′ = (getindices(indB_lin, pc), pB′[2]) + # if uncontracted indices are empty, we can still make cyclic adjustments - if NA₁ == 0 - shiftA = findfirst(==(first(pB′[1])), pB[1]) - @assert !isnothing(shiftA) "pB = $pB, pB′ = $pB′" - pA′ = (pA′[1], _circshift(pA′[2], shiftA-1)) + if NA₁ == 0 && NA₂ != 0 + hit = pA[2][findfirst(==(first(pB′[1])), pB[1])] + shiftA = findfirst(==(hit), pA′[2]) - 1 + pA′ = (pA′[1], _circshift(pA′[2], -shiftA)) end - - if NB₂ == 0 - shiftB = findfirst(==(first(pA′[2])), pA[2]) - @assert !isnothing(shiftB) "pA = $pA, pA′ = $pA′" - pB′ = (_circshift(pB′[1], shiftB-1), pB′[2]) + + if NB₂ == 0 && NB₁ != 0 + hit = pB[1][findfirst(==(first(pA′[2])), pA[2])] + shiftB = findfirst(==(hit), pB′[1]) - 1 + pB′ = (_circshift(pB′[1], -shiftB), pB′[2]) end # make sure this is still the same contraction @@ -203,14 +206,14 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) "$pA $pB $pA′ $pB′" - + # make sure that everything is now planar @assert _iscyclicpermutation((indA[1]..., (indA[2])...), - (pA′[1]..., reverse(pA′[2])...)) + (pA′[1]..., reverse(pA′[2])...)) "indA = $indA, pA′ = $pA′" @assert _iscyclicpermutation((indB[1]..., (indB[2])...), - (pB′[1]..., reverse(pB′[2])...)) + (pB′[1]..., reverse(pB′[2])...)) "indB = $indB, pB′ = $pB′" @assert _iscyclicpermutation((indAB[1]..., (indAB[2])...), - (pAB′[1]..., reverse(pAB′[2])...)) - + (pAB′[1]..., reverse(pAB′[2])...)) "indAB = $indAB, pAB′ = $pAB′" + return pA′, pB′, pAB′ end diff --git a/test/planar.jl b/test/planar.jl index ce8278e29..358b1137d 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -204,6 +204,19 @@ using TensorKit: reorder_planar_indices @test pB == pB′ @test pAB′ == ((1, 2, 3), (4, 5, 6)) end + + @testset "edge case with no uncontracted indices right" begin + pA = ((), (1, 2, 3)) + pB = ((2, 1, 3), ()) + pAB = ((), ()) + indA = ((1,), (3, 2)) + indB = ((1,), (3, 2)) + + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == ((), (2, 3, 1)) + @test pB′ == ((1, 3, 2), ()) + @test pAB′ == ((), ()) + end end @testset "planar methods" verbose = true begin From 3ed5ecf6739bf9766e2e99f0d1f653bb327482f0 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 30 Jan 2024 13:05:45 +0100 Subject: [PATCH 08/29] more fixes and tests --- src/planar/planaroperations.jl | 22 +++++++++++++--------- test/planar.jl | 10 ++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index ec4dfd962..8c8760127 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -60,7 +60,7 @@ function planarcontract!(C::AbstractTensorMap{S}, indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) - + if pA′ == (codomainind(A), domainind(A)) A′ = A else @@ -84,6 +84,7 @@ function planarcontract!(C::AbstractTensorMap{S}, mul!(C′, A′, B′, α, β) else C′ = A′ * B′ + add_transpose!(C, C′, pAB′, α, β) end @@ -147,22 +148,25 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) @assert NAB == NA₁ + NB₂ # find circshift index of pAB if considered as shifting sets - indAB = (ntuple(identity, NAB₁), reverse(ntuple(n -> n + NAB₁, NAB₂))) + indAB = (ntuple(identity, NA₁), reverse(ntuple(n -> n + NA₁, NB₂))) if NAB > 0 indAB_lin = (indAB[1]..., indAB[2]...) iAB = _findsetcircshift(indAB_lin, pAB[1]) @assert iAB == mod(_findsetcircshift(indAB_lin, pAB[2]) - NAB₁, NAB) "sanity check" - + indAB_lin = _circshift(indAB_lin, -iAB) # migrate permutations from pAB to pA and pB - permA = getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB, NAB), NA₁)) - permB = reverse(getindices((pAB[1]..., reverse(pAB[2])...), - ntuple(n -> mod1(n + iAB + NA₁, NAB), NB₂)) .- NA₁) + + pAB_lin = (pAB[1]..., reverse(pAB[2])...) + permA = getindices(pAB_lin, + ntuple(n -> mod1(n - iAB, NAB), NA₁)) + permB = reverse(getindices(pAB_lin, + ntuple(n -> mod1(n - iAB + NA₁, NAB), NB₂)) .- NA₁) pA′ = (getindices(pA[1], permA), pA[2]) pB′ = (pB[1], getindices(pB[2], permB)) - pAB′ = (ntuple(n -> n + iAB, NAB₁), ntuple(n -> n + iAB + NAB₁, NAB₂)) + pAB′ = (getindices(indAB_lin, ntuple(n -> mod1(n, NAB), NAB₁)), + reverse(getindices(indAB_lin, ntuple(n -> mod1(n + NAB₁, NAB), NAB₂)))) else pA′ = pA pB′ = pB @@ -204,7 +208,7 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) # make sure this is still the same contraction @assert issetequal(pA[1], pA′[1]) && issetequal(pA[2], pA′[2]) @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) - @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) + # @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) "pAB = $pAB, pAB′ = $pAB′" @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) "$pA $pB $pA′ $pB′" # make sure that everything is now planar diff --git a/test/planar.jl b/test/planar.jl index 358b1137d..5b3ae9868 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -217,6 +217,16 @@ using TensorKit: reorder_planar_indices @test pB′ == ((1, 3, 2), ()) @test pAB′ == ((), ()) end + + @testset "something" begin + pA, pB, pAB = (((1, 2, 3), (4, 5)), ((3, 1), (4, 2)), ((3, 5), (2, 1, 4))) + indA = ((1, 2, 3), (5, 4)) + indB = ((1, 2), (4, 3)) + pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) + @test pA′ == pA + @test pB′ == pB + @test pAB′ == pAB + end end @testset "planar methods" verbose = true begin From d38867d4faf3e5b8b431da401ec5c50803ba2687 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 30 Jan 2024 13:07:59 +0100 Subject: [PATCH 09/29] Ad improvements --- ext/TensorKitChainRulesCoreExt.jl | 57 +++++++++++++++++++------------ test/ad.jl | 51 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index a3f02dff0..322473aed 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -1,17 +1,21 @@ module TensorKitChainRulesCoreExt using TensorOperations +using TensorOperations: Backend, promote_contract using TensorKit using TensorKit: planaradd!, planarcontract!, planarcontract, _canonicalize +using VectorInterface using ChainRulesCore using LinearAlgebra using TupleTools +using TupleTools: getindices # Utility # ------- _conj(conjA::Symbol) = conjA == :C ? :N : :C trivtuple(N) = ntuple(identity, N) +trivtuple(::Index2Tuple{N₁,N₂}) where {N₁,N₂} = trivtuple(N₁ + N₂) function _repartition(p::IndexTuple, N₁::Int) length(p) >= N₁ || @@ -113,8 +117,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipA = (codomainind(A), domainind(A)) pB = (allind(B), ()) dA = zerovector(A, - TensorOperations.promote_contract(scalartype(ΔC), - scalartype(B))) + promote_contract(scalartype(ΔC), scalartype(B))) dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C) return projectA(dA) end @@ -122,8 +125,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipB = (codomainind(B), domainind(B)) pA = ((), allind(A)) dB = zerovector(B, - TensorOperations.promote_contract(scalartype(ΔC), - scalartype(A))) + promote_contract(scalartype(ΔC), scalartype(A))) dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N) return projectB(dB) end @@ -650,19 +652,21 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorM dA = @thunk begin ip = _canonicalize(invperm(linearize(p)), A) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) - _dA = planaradd!(_dA, ip, ΔC, conj(α), Zero(), backend...) + _dA = planaradd!(_dA, ΔC, ip, conj(α), Zero(), backend...) return projectA(_dA) end dα = @thunk begin - _dα = tensorscalar(planarcontract(A, ((), linearize(p)), - ΔC, (trivtuple(numind(p)), ()), - ((), ()), One(), Zero(), backend...)) + p′ = TensorKit.adjointtensorindices(A, p) + _dα = tensorscalar(planarcontract(A', ((), linearize(p′)), + ΔC, (trivtuple(p), ()), + ((), ()), One(), backend...)) return projectα(_dα) end dβ = @thunk begin - _dβ = tensorscalar(planarcontract(C, ((), trivtuple(numind(pC))), - ΔC, (trivtuple(numind(pC)), ()), - ((), ()), One(), Zero(), backend...)) + p′ = TensorKit.adjointtensorindices(C, trivtuple(p)) + _dβ = tensorscalar(planarcontract(C', ((), p′), + ΔC, (trivtuple(p), ()), + ((), ()), One(), backend...)) return projectβ(_dβ) end dbackend = map(x -> NoTangent(), backend) @@ -673,11 +677,14 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorM end function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), - C::AbstractTensorMap{S,N₁,N₂}, + C::AbstractTensorMap{S,N₁,N₂}, A::AbstractTensorMap{S}, pA::Index2Tuple, B::AbstractTensorMap{S}, pB::Index2Tuple, pAB::Index2Tuple{N₁,N₂}, α::Number, β::Number, backend::Backend...) where {S,N₁,N₂} + indA = (codomainind(A), reverse(domainind(A))) + indB = (codomainind(B), reverse(domainind(B))) + pA, pB, pAB = TensorKit.reorder_planar_indices(indA, pA, indB, pB, pAB) C′ = planarcontract!(copy(C), A, pA, B, pB, pAB, α, β, backend...) projectA = ProjectTo(A) @@ -688,33 +695,39 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), function planarcontract_pullback(ΔC′) ΔC = unthunk(ΔC′) - pΔC = _canonicalize(invperm(linearize(pAB)), ΔC) + ipAB = invperm(linearize(pAB)) + pΔC = (getindices(ipAB, trivtuple(length(pA[1]))), + getindices(ipAB, length(pA[1]) .+ trivtuple(length(pB[2])))) dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk begin - ipA = _canonicalize(invperm(linearize(pA)), (), A) + ipA = _canonicalize(invperm(linearize(pA)), A) _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) - _dA = planarcontract!(_dA, ΔC, pΔC, adjoint(B), reverse(pB), ipA, + pB′ = TensorKit.adjointtensorindices(B, reverse(pB)) + _dA = planarcontract!(_dA, ΔC, pΔC, adjoint(B), pB′, ipA, conj(α), Zero(), backend...) return projectA(_dA) end dB = @thunk begin ipB = _canonicalize((invperm(linearize(pB)), ()), B) _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) - _dB = planarcontract!(_dB, A', reverse(pA), ΔC, pΔC, ipB, + pA′ = TensorKit.adjointtensorindices(A, reverse(pA)) + _dB = planarcontract!(_dB, adjoint(A), pA′, ΔC, pΔC, ipB, conj(α), Zero(), backend...) return projectB(_dB) end dα = @thunk begin AB = planarcontract!(similar(C), A, pA, B, pB, pAB, One(), Zero(), backend...) - _dα = tensorscalar(planarcontract(AB', ((), trivtuple(numind(pAB))), - ΔC, (trivtuple(numind(pAB)), ()), ((), ()), - One(), Zero(), backend...)) + p′ = TensorKit.adjointtensorindices(AB, trivtuple(pAB)) + _dα = tensorscalar(planarcontract(AB', ((), p′), + ΔC, (trivtuple(pAB), ()), ((), ()), + One(), backend...)) return projectα(_dα) end dβ = @thunk begin - _dβ = tensorscalar(planarcontract(C', ((), trivtuple(numind(pAB))), - ΔC, (trivtuple(numind(pAB)), ()), ((), ()), - One(), Zero(), backend...)) + p′ = TensorKit.adjointtensorindices(C, trivtuple(pAB)) + _dβ = tensorscalar(planarcontract(C', ((), p′), + ΔC, (trivtuple(pAB), ()), ((), ()), + One(), backend...)) return projectβ(_dβ) end dbackend = map(x -> NoTangent(), backend) diff --git a/test/ad.jl b/test/ad.jl index d686a8341..9587916dd 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -3,6 +3,7 @@ using ChainRulesTestUtils using Random using FiniteDifferences using LinearAlgebra +using TensorKit: ℙ, planaradd!, planarcontract! ## Test utility # ------------- @@ -106,11 +107,17 @@ end ChainRulesTestUtils.test_method_tables() Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + (ℙ^2, (ℙ^3)', ℙ^3, ℙ^2, (ℙ^2)'), (ℂ[Z2Irrep](0 => 1, 1 => 1), ℂ[Z2Irrep](0 => 1, 1 => 2)', ℂ[Z2Irrep](0 => 3, 1 => 2)', ℂ[Z2Irrep](0 => 2, 1 => 3), ℂ[Z2Irrep](0 => 2, 1 => 2)), + (ℂ[FermionParity](0 => 1, 1 => 1), + ℂ[FermionParity](0 => 1, 1 => 2)', + ℂ[FermionParity](0 => 3, 1 => 2)', + ℂ[FermionParity](0 => 2, 1 => 3), + ℂ[FermionParity](0 => 2, 1 => 2)), (ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 2), ℂ[U1Irrep](0 => 3, 1 => 1, -1 => 1), ℂ[U1Irrep](0 => 2, 1 => 2, -1 => 1)', @@ -221,6 +228,50 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), end end + @testset "PlanarOperations with scalartype $T" for T in (Float64, ComplexF64) + atol = precision(T) + rtol = precision(T) + + @testset "planaradd!" begin + p = ((4, 3, 1), (5, 2)) + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) + α = randn(T) + β = randn(T) + test_rrule(planaradd!, C, A, p, α, β; atol, rtol) + end + + @testset "planarcontract! 1" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, V[1] ⊗ V[5] ← V[5] ⊗ V[2]) + pA = ((1, 3, 4), (5, 2)) + pB = ((2, 4), (1, 3)) + pAB = ((3, 2, 1), (4, 5)) + + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, :N, + B, pB, :N, false)) + test_rrule(planarcontract!, C, A, pA, B, pB, pAB, α, β; atol, rtol) + end + + @testset "planarcontract! 2" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, V[3] ⊗ V[4] ⊗ V[5] ← V[1] ⊗ V[2]) + pA = ((1, 2), (3, 4, 5)) + pB = ((1, 2, 3), (4, 5)) + pAB = ((1, 2), (3, 4)) + + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, :N, + B, pB, :N, false)) + test_rrule(planarcontract!, C, A, pA, B, pB, pAB, α, β; atol, rtol) + end + end + @testset "Factorizations with scalartype $T" for T in (Float64, ComplexF64) A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, space(A)') From e405b61ca374399efc6a9e1511e6162d97fa547a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 30 Jan 2024 13:08:47 +0100 Subject: [PATCH 10/29] import `getindices` --- src/fusiontrees/manipulations.jl | 8 ++++---- src/tensors/indexmanipulations.jl | 4 ++-- src/tensors/tensor.jl | 4 ++-- src/tensors/tensoroperations.jl | 20 ++++++++++---------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index cea269fad..98da4739c 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -580,11 +580,11 @@ function planar_trace(f₁::FusionTree{I}, f₂::FusionTree{I}, linearindex = (ntuple(identity, Val(length(f₁)))..., reverse(length(f₁) .+ ntuple(identity, Val(length(f₂))))...) - q1′ = TupleTools.getindices(linearindex, q1) - q2′ = TupleTools.getindices(linearindex, q2) + q1′ = getindices(linearindex, q1) + q2′ = getindices(linearindex, q2) p1′, p2′ = let q′ = (q1′..., q2′...) - (map(l -> l - count(l .> q′), TupleTools.getindices(linearindex, p1)), - map(l -> l - count(l .> q′), TupleTools.getindices(linearindex, p2))) + (map(l -> l - count(l .> q′), getindices(linearindex, p1)), + map(l -> l - count(l .> q′), getindices(linearindex, p2))) end u = one(I) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 338b701b3..e4b2277b2 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -248,8 +248,8 @@ end length(levels) == numind(tsrc) || throw(ArgumentError("incorrect levels $levels for tensor map $(codomain(tsrc)) ← $(domain(tsrc))")) - levels1 = TupleTools.getindices(levels, codomainind(tsrc)) - levels2 = TupleTools.getindices(levels, domainind(tsrc)) + levels1 = getindices(levels, codomainind(tsrc)) + levels2 = getindices(levels, domainind(tsrc)) # TODO: arg order for tensormaps is different than for fusiontrees treebraider(f₁, f₂) = braid(f₁, f₂, levels1, levels2, p...) return add_transform!(tdst, tsrc, p, treebraider, α, β, backend...) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 5c8158677..73d296358 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -437,8 +437,8 @@ fusiontrees(t::TensorMap) = TensorKeyIterator(t.rowr, t.colr) sectors::Tuple{Vararg{I}}) where {N₁,N₂,I<:Sector} FusionStyle(I) isa UniqueFusion || throw(SectorMismatch("Indexing with sectors only possible if unique fusion")) - s1 = TupleTools.getindices(sectors, codomainind(t)) - s2 = map(dual, TupleTools.getindices(sectors, domainind(t))) + s1 = getindices(sectors, codomainind(t)) + s2 = map(dual, getindices(sectors, domainind(t))) c1 = length(s1) == 0 ? one(I) : (length(s1) == 1 ? s1[1] : first(⊗(s1...))) @boundscheck begin c2 = length(s2) == 0 ? one(I) : (length(s2) == 1 ? s2[1] : first(⊗(s2...))) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 9f8ed5a91..8b7c1c477 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -28,8 +28,8 @@ function _canonicalize(p::Index2Tuple{N₁,N₂}, end _canonicalize(p::Index2Tuple, t::AbstractTensorMap) = _canonicalize(linearize(p), t) function _canonicalize(p::IndexTuple, t::AbstractTensorMap) - p₁ = TupleTools.getindices(p, codomainind(t)) - p₂ = TupleTools.getindices(p, domainind(t)) + p₁ = getindices(p, codomainind(t)) + p₂ = getindices(p, domainind(t)) return (p₁, p₂) end @@ -230,16 +230,16 @@ function contract!(C::AbstractTensorMap{S}, # find optimal contraction scheme hsp = has_shared_permute ipC = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂)) + oindAinC = getindices(ipC, ntuple(n -> n, N₁)) + oindBinC = getindices(ipC, ntuple(n -> n + N₁, N₂)) qA = TupleTools.sortperm(cindA) - cindA′ = TupleTools.getindices(cindA, qA) - cindB′ = TupleTools.getindices(cindB, qA) + cindA′ = getindices(cindA, qA) + cindB′ = getindices(cindB, qA) qB = TupleTools.sortperm(cindB) - cindA′′ = TupleTools.getindices(cindA, qB) - cindB′′ = TupleTools.getindices(cindB, qB) + cindA′′ = getindices(cindA, qB) + cindB′′ = getindices(cindB, qB) dA, dB, dC = dim(A), dim(B), dim(C) @@ -301,8 +301,8 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S}, end end ipC = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂)) + oindAinC = getindices(ipC, ntuple(n -> n, N₁)) + oindBinC = getindices(ipC, ntuple(n -> n + N₁, N₂)) if has_shared_permute(C, (oindAinC, oindBinC)) C′ = permute(C, (oindAinC, oindBinC)) mul!(C′, A′, B′, α, β) From b42d11920bb4b9a9b50569644e29396e08e96185 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 15:46:25 +0200 Subject: [PATCH 11/29] Add planarcontract_indices --- src/TensorKit.jl | 1 + src/planar/indices.jl | 101 ++++++++++++++++++++++++++++++++++++++++++ test/planar.jl | 22 ++++++++- 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 src/planar/indices.jl diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 5aaa68924..b3169e4ec 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -195,6 +195,7 @@ include("tensors/braidingtensor.jl") # #----------------------------------------- @nospecialize using Base.Meta: isexpr +include("planar/indices.jl") include("planar/analyzers.jl") include("planar/preprocessors.jl") include("planar/postprocessors.jl") diff --git a/src/planar/indices.jl b/src/planar/indices.jl new file mode 100644 index 000000000..f1decfc50 --- /dev/null +++ b/src/planar/indices.jl @@ -0,0 +1,101 @@ +""" + planarcontract_indices(IA, IB, IC) + +Convert a set of tensor labels to a set of indices. Throws an error if this cannot be achieved in a planar manner. +""" +function planarcontract_indices(IA::Tuple{NTuple{NA1},NTuple{NA2}}, + IB::Tuple{NTuple{NB1},NTuple{NB2}}, + IC::Tuple{NTuple{NC1},NTuple{NC2}}) where {NA1,NA2,NB1,NB2, + NC1,NC2} + IA_linear = (IA[1]..., reverse(IA[2])...) + IB_linear = (IB[1]..., reverse(IB[2])...) + IC_linear = (IC[1]..., reverse(IC[2])...) + IAB = (IA_linear..., IB_linear...) + isodd(length(IAB) - length(IC_linear)) && + throw(IndexError("invalid contraction pattern: $IA and $IB to $IC")) + + Icontract = TO.tunique(TO.tsetdiff(IAB, IC_linear)) + IopenA = TO.tsetdiff(IA_linear, Icontract) + IopenB = TO.tsetdiff(IB_linear, Icontract) + + # bring IA to the form (IopenA..., Icontract...) (as sets) + IA′ = IA_linear + ctr = 0 + while !issetequal(getindices(IA′, ntuple(identity, length(IopenA))), IopenA) + IA′ = _cyclicpermute(IA′) + ctr += 1 + ctr > length(IA′) && + throw(ArgumentError("no cyclic permutation of $IA that matches $IB")) + end + + # bring IB to the form (Icontract..., IopenB...) (as sets) + IB′ = IB_linear + ctr = 0 + while !issetequal(getindices(IB′, ntuple(i -> i + length(Icontract), length(IopenB))), + IopenB) + IB′ = _cyclicpermute(IB′) + ctr += 1 + ctr > length(IB′) && + throw(ArgumentError("no cyclic permutation of $IB that matches $IA")) + end + + # special case when IopenA is empty -> still have freedom to circshift IA + if length(IopenA) == 0 + ctr = 0 + while !isequal(IA′, reverse(getindices(IB′, ntuple(identity, length(IA′))))) + IA′ = _cyclicpermute(IA′) + ctr += 1 + ctr > length(IA′) && + throw(ArgumentError("no cyclic permutation of $IA that matches $IB")) + end + end + + # special case when IopenB is empty -> still have freedom to circshift IB + if length(IopenB) == 0 + ctr = 0 + while !isequal(IB′, + reverse(getindices(IA′, + ntuple(i -> i + length(IopenA), length(IB′))))) + IB′ = _cyclicpermute(IB′) + ctr += 1 + ctr > length(IB′) && + throw(ArgumentError("no cyclic permutation of $IB that matches $IA")) + end + end + + # bring IC to the form (IopenA..., IopenB...) (as sets) + IC′ = IC_linear + IopenA + ctr = 0 + while !issetequal(getindices(IC′, ntuple(identity, length(IopenA))), IopenA) + IC′ = _cyclicpermute(IC′) + ctr += 1 + ctr > length(IC′) && + throw(ArgumentError("no cyclic permutation of $IC that matches $IA and $IB")) + end + + # special case when Icontract is empty -> still have freedom to circshift IA and IB to + # match IC + # TODO: this is not yet implemented + @assert length(Icontract) != 0 "not yet implemented" + + IA_nonlinear = (IA[1]..., IA[2]...) + pA = (_indexin(getindices(IA′, ntuple(identity, length(IopenA))), IA_nonlinear), + reverse(_indexin(getindices(IA′, + ntuple(i -> i + length(IopenA), length(Icontract))), + IA_nonlinear))) + + IB_nonlinear = (IB[1]..., IB[2]...) + pB = (_indexin(getindices(IB′, ntuple(identity, length(Icontract))), IB_nonlinear), + reverse(_indexin(getindices(IB′, + ntuple(i -> i + length(Icontract), length(IopenB))), + IB_nonlinear))) + + IC″ = (ntuple(i -> IC′[i], length(IopenA))..., + ntuple(i -> IC′[end + 1 - i], length(IopenB))...) + invIC = _indexin(IC_linear, IC″) + pC = (ntuple(i -> invIC[i], NC1), + ntuple(i -> invIC[end + 1 - i], NC2)) + + return pA, pB, pC +end diff --git a/test/planar.jl b/test/planar.jl index 5b3ae9868..d34c91d47 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -27,6 +27,24 @@ function force_planar(tsrc::TensorMap{<:GradedSpace}) return tdst end +using TensorKit: planarcontract_indices + +@testset "planarcontract_indices" begin + IA = ((:a, :b, :c), (:d,)) + IB = ((:d,), (:e, :f)) + IC = ((:a, :b, :c), (:e, :f)) + pA, pB, pC = planarcontract_indices(IA, IB, IC) + @test pA == ((1, 2, 3), (4,)) + @test pB == ((1,), (2, 3)) + @test pC == ((1, 2, 3), (4, 5)) + + IC = ((:a, :b, :c, :f, :e), ()) + pA, pB, pC = planarcontract_indices(IA, IB, IC) + @test pA == ((1, 2, 3), (4,)) + @test pB == ((1,), (2, 3)) + @test pC == ((1, 2, 3, 5, 4), ()) +end + using TensorKit: reorder_planar_indices @testset "reorder_indices" begin @testset "trivial case" begin @@ -204,7 +222,7 @@ using TensorKit: reorder_planar_indices @test pB == pB′ @test pAB′ == ((1, 2, 3), (4, 5, 6)) end - + @testset "edge case with no uncontracted indices right" begin pA = ((), (1, 2, 3)) pB = ((2, 1, 3), ()) @@ -217,7 +235,7 @@ using TensorKit: reorder_planar_indices @test pB′ == ((1, 3, 2), ()) @test pAB′ == ((), ()) end - + @testset "something" begin pA, pB, pAB = (((1, 2, 3), (4, 5)), ((3, 1), (4, 2)), ((3, 5), (2, 1, 4))) indA = ((1, 2, 3), (5, 4)) From ccb9ede650fdbfa41c164b7fd90e033ae35efd4b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:08:49 +0200 Subject: [PATCH 12/29] Move planar index functions to separate file --- src/fusiontrees/manipulations.jl | 4 ++ src/planar/indices.jl | 92 +++++++++++++++++++++++++++++++ src/planar/planaroperations.jl | 93 -------------------------------- test/planar.jl | 14 +++++ 4 files changed, 110 insertions(+), 93 deletions(-) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 98da4739c..ef1495e32 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -370,6 +370,10 @@ function iscyclicpermutation(v1, v2) length(v1) == length(v2) || return false return iscyclicpermutation(indexin(v1, v2)) end +function iscyclicpermutation(v1::Tuple, v2::Tuple) + length(v1) == length(v2) || return false + return iscyclicpermutation(TupleTools.indexin(v1, v2)) +end # clockwise cyclic permutation while preserving (N₁, N₂): foldright & bendleft function cycleclockwise(f₁::FusionTree{I}, f₂::FusionTree{I}) where {I<:Sector} diff --git a/src/planar/indices.jl b/src/planar/indices.jl index f1decfc50..6fcc72c1e 100644 --- a/src/planar/indices.jl +++ b/src/planar/indices.jl @@ -99,3 +99,95 @@ function planarcontract_indices(IA::Tuple{NTuple{NA1},NTuple{NA2}}, return pA, pB, pC end + +function reorder_planar_indices(indA, pA, indB, pB, pAB) + NA₁ = length(pA[1]) + NA₂ = length(pA[2]) + NA = NA₁ + NA₂ + NB₁ = length(pB[1]) + NB₂ = length(pB[2]) + NB = NB₁ + NB₂ + NAB₁ = length(pAB[1]) + NAB₂ = length(pAB[2]) + NAB = NAB₁ + NAB₂ + + # input checks + @assert NA == length(indA[1]) + length(indA[2]) + @assert NB == length(indB[1]) + length(indB[2]) + @assert NA₂ == NB₁ + @assert NAB == NA₁ + NB₂ + + # find circshift index of pAB if considered as shifting sets + indAB = (ntuple(identity, NA₁), reverse(ntuple(n -> n + NA₁, NB₂))) + + if NAB > 0 + indAB_lin = (indAB[1]..., indAB[2]...) + iAB = _findsetcircshift(indAB_lin, pAB[1]) + @assert iAB == mod(_findsetcircshift(indAB_lin, pAB[2]) - NAB₁, NAB) "sanity check" + indAB_lin = _circshift(indAB_lin, -iAB) + # migrate permutations from pAB to pA and pB + + pAB_lin = (pAB[1]..., reverse(pAB[2])...) + permA = getindices(pAB_lin, + ntuple(n -> mod1(n - iAB, NAB), NA₁)) + permB = reverse(getindices(pAB_lin, + ntuple(n -> mod1(n - iAB + NA₁, NAB), NB₂)) .- NA₁) + + pA′ = (getindices(pA[1], permA), pA[2]) + pB′ = (pB[1], getindices(pB[2], permB)) + pAB′ = (getindices(indAB_lin, ntuple(n -> mod1(n, NAB), NAB₁)), + reverse(getindices(indAB_lin, ntuple(n -> mod1(n + NAB₁, NAB), NAB₂)))) + else + pA′ = pA + pB′ = pB + pAB′ = pAB + end + + # cycle indA to be of the form (oindA..., reverse(cindA)...) + indA_lin = (indA[1]..., indA[2]...) + if NA₁ != 0 + iA = findfirst(==(first(pA′[1])), indA_lin) - 1 + indA_lin = _circshift(indA_lin, -iA) + end + pc = ntuple(identity, NA₂) + @assert all(getindices(indA_lin, ntuple(identity, NA₁)) .== pA′[1]) "sanity check" + pA′ = (pA′[1], reverse(getindices(indA_lin, pc .+ NA₁))) + + # cycle indB to be of the form (cindB..., reverse(oindB)...) + indB_lin = (indB[1]..., indB[2]...) + if NB₂ != 0 + iB = findfirst(==(first(pB′[2])), indB_lin) + indB_lin = _circshift(indB_lin, -iB) + end + @assert all(getindices(indB_lin, ntuple(identity, NB₂) .+ NB₁) .== reverse(pB′[2])) "sanity check" + pB′ = (getindices(indB_lin, pc), pB′[2]) + + # if uncontracted indices are empty, we can still make cyclic adjustments + if NA₁ == 0 && NA₂ != 0 + hit = pA[2][findfirst(==(first(pB′[1])), pB[1])] + shiftA = findfirst(==(hit), pA′[2]) - 1 + pA′ = (pA′[1], _circshift(pA′[2], -shiftA)) + end + + if NB₂ == 0 && NB₁ != 0 + hit = pB[1][findfirst(==(first(pA′[2])), pA[2])] + shiftB = findfirst(==(hit), pB′[1]) - 1 + pB′ = (_circshift(pB′[1], -shiftB), pB′[2]) + end + + # make sure this is still the same contraction + @assert issetequal(pA[1], pA′[1]) && issetequal(pA[2], pA′[2]) + @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) + # @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) "pAB = $pAB, pAB′ = $pAB′" + @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) "$pA $pB $pA′ $pB′" + + # make sure that everything is now planar + @assert _iscyclicpermutation((indA[1]..., (indA[2])...), + (pA′[1]..., reverse(pA′[2])...)) "indA = $indA, pA′ = $pA′" + @assert _iscyclicpermutation((indB[1]..., (indB[2])...), + (pB′[1]..., reverse(pB′[2])...)) "indB = $indB, pB′ = $pB′" + @assert _iscyclicpermutation((indAB[1]..., (indAB[2])...), + (pAB′[1]..., reverse(pAB′[2])...)) "indAB = $indAB, pAB′ = $pAB′" + + return pA′, pB′, pAB′ +end diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 8c8760127..8b226265a 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -84,7 +84,6 @@ function planarcontract!(C::AbstractTensorMap{S}, mul!(C′, A′, B′, α, β) else C′ = A′ * B′ - add_transpose!(C, C′, pAB′, α, β) end @@ -129,95 +128,3 @@ function _findsetcircshift(p_cyclic, p_subset) throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) return i - 1::Int end - -function reorder_planar_indices(indA, pA, indB, pB, pAB) - NA₁ = length(pA[1]) - NA₂ = length(pA[2]) - NA = NA₁ + NA₂ - NB₁ = length(pB[1]) - NB₂ = length(pB[2]) - NB = NB₁ + NB₂ - NAB₁ = length(pAB[1]) - NAB₂ = length(pAB[2]) - NAB = NAB₁ + NAB₂ - - # input checks - @assert NA == length(indA[1]) + length(indA[2]) - @assert NB == length(indB[1]) + length(indB[2]) - @assert NA₂ == NB₁ - @assert NAB == NA₁ + NB₂ - - # find circshift index of pAB if considered as shifting sets - indAB = (ntuple(identity, NA₁), reverse(ntuple(n -> n + NA₁, NB₂))) - - if NAB > 0 - indAB_lin = (indAB[1]..., indAB[2]...) - iAB = _findsetcircshift(indAB_lin, pAB[1]) - @assert iAB == mod(_findsetcircshift(indAB_lin, pAB[2]) - NAB₁, NAB) "sanity check" - indAB_lin = _circshift(indAB_lin, -iAB) - # migrate permutations from pAB to pA and pB - - pAB_lin = (pAB[1]..., reverse(pAB[2])...) - permA = getindices(pAB_lin, - ntuple(n -> mod1(n - iAB, NAB), NA₁)) - permB = reverse(getindices(pAB_lin, - ntuple(n -> mod1(n - iAB + NA₁, NAB), NB₂)) .- NA₁) - - pA′ = (getindices(pA[1], permA), pA[2]) - pB′ = (pB[1], getindices(pB[2], permB)) - pAB′ = (getindices(indAB_lin, ntuple(n -> mod1(n, NAB), NAB₁)), - reverse(getindices(indAB_lin, ntuple(n -> mod1(n + NAB₁, NAB), NAB₂)))) - else - pA′ = pA - pB′ = pB - pAB′ = pAB - end - - # cycle indA to be of the form (oindA..., reverse(cindA)...) - indA_lin = (indA[1]..., indA[2]...) - if NA₁ != 0 - iA = findfirst(==(first(pA′[1])), indA_lin) - 1 - indA_lin = _circshift(indA_lin, -iA) - end - pc = ntuple(identity, NA₂) - @assert all(getindices(indA_lin, ntuple(identity, NA₁)) .== pA′[1]) "sanity check" - pA′ = (pA′[1], reverse(getindices(indA_lin, pc .+ NA₁))) - - # cycle indB to be of the form (cindB..., reverse(oindB)...) - indB_lin = (indB[1]..., indB[2]...) - if NB₂ != 0 - iB = findfirst(==(first(pB′[2])), indB_lin) - indB_lin = _circshift(indB_lin, -iB) - end - @assert all(getindices(indB_lin, ntuple(identity, NB₂) .+ NB₁) .== reverse(pB′[2])) "sanity check" - pB′ = (getindices(indB_lin, pc), pB′[2]) - - # if uncontracted indices are empty, we can still make cyclic adjustments - if NA₁ == 0 && NA₂ != 0 - hit = pA[2][findfirst(==(first(pB′[1])), pB[1])] - shiftA = findfirst(==(hit), pA′[2]) - 1 - pA′ = (pA′[1], _circshift(pA′[2], -shiftA)) - end - - if NB₂ == 0 && NB₁ != 0 - hit = pB[1][findfirst(==(first(pA′[2])), pA[2])] - shiftB = findfirst(==(hit), pB′[1]) - 1 - pB′ = (_circshift(pB′[1], -shiftB), pB′[2]) - end - - # make sure this is still the same contraction - @assert issetequal(pA[1], pA′[1]) && issetequal(pA[2], pA′[2]) - @assert issetequal(pB[1], pB′[1]) && issetequal(pB[2], pB′[2]) - # @assert issetequal(pAB[1], pAB′[1]) && issetequal(pAB[2], pAB′[2]) "pAB = $pAB, pAB′ = $pAB′" - @assert issetequal(tuple.(pA[2], pB[1]), tuple.(pA′[2], pB′[1])) "$pA $pB $pA′ $pB′" - - # make sure that everything is now planar - @assert _iscyclicpermutation((indA[1]..., (indA[2])...), - (pA′[1]..., reverse(pA′[2])...)) "indA = $indA, pA′ = $pA′" - @assert _iscyclicpermutation((indB[1]..., (indB[2])...), - (pB′[1]..., reverse(pB′[2])...)) "indB = $indB, pB′ = $pB′" - @assert _iscyclicpermutation((indAB[1]..., (indAB[2])...), - (pAB′[1]..., reverse(pAB′[2])...)) "indAB = $indAB, pAB′ = $pAB′" - - return pA′, pB′, pAB′ -end diff --git a/test/planar.jl b/test/planar.jl index d34c91d47..44cbd4961 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -43,6 +43,20 @@ using TensorKit: planarcontract_indices @test pA == ((1, 2, 3), (4,)) @test pB == ((1,), (2, 3)) @test pC == ((1, 2, 3, 5, 4), ()) + + IC = ((:c, :f), (:b, :a, :e)) + pA, pB, pC = planarcontract_indices(IA, IB, IC) + @test pA == ((1, 2, 3), (4,)) + @test pB == ((1,), (2, 3)) + @test pC == ((3, 5), (2, 1, 4)) + + IA = ((:a, :b, :e), (:d, :c)) + IB = ((:c,), (:e, :f, :g)) + IC = ((:d, :a, :b), (:f, :g)) + pA, pB, pC = planarcontract_indices(IA, IB, IC) + @test pA == ((4, 1, 2), (5, 3)) + @test pB == ((2, 1), (3, 4)) + @test pC == ((1, 2, 3), (4, 5)) end using TensorKit: reorder_planar_indices From a168fe9ed7af50040b72501dafddd04ee7a20299 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:09:20 +0200 Subject: [PATCH 13/29] Add planarcontract with conj flags --- src/planar/planaroperations.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 8b226265a..a5ff2bcb9 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -44,6 +44,33 @@ function planartrace!(C::AbstractTensorMap{S,N₁,N₂}, return C end +function planarcontract!(C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol, + B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol, + pAB::Index2Tuple, + α::Number, β::Number, backend::Backend...) + # get rid of conj arguments by going to adjoint tensormaps + if conjA == :N + A′ = A + pA′ = pA + elseif conjA == :C + A′ = A' + pA′ = adjointtensorindices(A, pA) + else + throw(ArgumentError("unknown conjugation flag $conjA")) + end + if conjB == :N + B′ = B + pB′ = pB + elseif conjB == :C + B′ = B' + pB′ = adjointtensorindices(B, pB) + else + throw(ArgumentError("unknown conjugation flag $conjB")) + end + + return planarcontract!(C, A, pA′, B, pB′, pAB, α, β, backend...) +end function planarcontract!(C::AbstractTensorMap{S}, A::AbstractTensorMap{S}, pA::Index2Tuple{N₁,N₃}, From fdad8f7195f7b9e06426be084f8e1781bfb22c4a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:09:45 +0200 Subject: [PATCH 14/29] Add some extra checks in planarcontract implementation --- src/planar/planaroperations.jl | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index a5ff2bcb9..38edeb477 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -71,6 +71,12 @@ function planarcontract!(C::AbstractTensorMap, return planarcontract!(C, A, pA′, B, pB′, pAB, α, β, backend...) end + +function _isplanar(inds::Index2Tuple, p::Index2Tuple) + return iscyclicpermutation((inds[1]..., inds[2]...), + (p[1]..., reverse(p[2])...)) +end + function planarcontract!(C::AbstractTensorMap{S}, A::AbstractTensorMap{S}, pA::Index2Tuple{N₁,N₃}, @@ -80,14 +86,19 @@ function planarcontract!(C::AbstractTensorMap{S}, α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} - if BraidingStyle(sectortype(S)) == Bosonic() - return contract!(C, A, pA, B, pB, pAB, α, β, backend...) - end - indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) + indC = (codomainind(C), reverse(domainind(C))) pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) - + + @assert _isplanar(indA, pA′) "not a planar contraction (indA = $indA, pA′ = $pA′)" + @assert _isplanar(indB, pB′) "not a planar contraction (pB′ = $pB′)" + @assert _isplanar(indC, pAB′) "not a planar contraction (pAB′ = $pAB′)" + + if BraidingStyle(sectortype(spacetype(C))) == Bosonic() + return contract!(C, A, pA′, B, pB′, pAB′, α, β, backend...) + end + if pA′ == (codomainind(A), domainind(A)) A′ = A else From fccb958ec42a26156ebdee741c07225b53b72e8b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:10:50 +0200 Subject: [PATCH 15/29] Also move auxiliary index functions --- src/planar/indices.jl | 27 +++++++++++++++++++++++++++ src/planar/planaroperations.jl | 27 --------------------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/planar/indices.jl b/src/planar/indices.jl index 6fcc72c1e..1ae93b56f 100644 --- a/src/planar/indices.jl +++ b/src/planar/indices.jl @@ -191,3 +191,30 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) return pA′, pB′, pAB′ end + +# auxiliary routines +_cyclicpermute(t::Tuple) = (Base.tail(t)..., t[1]) +_cyclicpermute(t::Tuple{}) = () + +_circshift(::Tuple{}, ::Int) = () +_circshift(t::Tuple, n::Int) = ntuple(i -> t[mod1(i - n, length(t))], length(t)) + +_indexin(v1, v2) = ntuple(n -> findfirst(isequal(v1[n]), v2), length(v1)) + +function _iscyclicpermutation(v1, v2) + length(v1) == length(v2) || return false + return iscyclicpermutation(_indexin(v1, v2)) +end + +function _findsetcircshift(p_cyclic, p_subset) + N = length(p_cyclic) + M = length(p_subset) + N == M == 0 && return 0 + i = findfirst(0:(N - 1)) do i + return issetequal(getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), + p_subset) + end + isnothing(i) && + throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) + return i - 1::Int +end diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 38edeb477..5964baa40 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -139,30 +139,3 @@ function planarcontract(A::AbstractTensorMap{S}, pA::Index2Tuple, C = TO.tensoralloc_contract(TC, pAB, A, pA, :N, B, pB, :N) return planarcontract!(C, A, pA, B, pB, pAB, α, VectorInterface.Zero(), backend...) end - -# auxiliary routines -_cyclicpermute(t::Tuple) = (Base.tail(t)..., t[1]) -_cyclicpermute(t::Tuple{}) = () - -_circshift(::Tuple{}, ::Int) = () -_circshift(t::Tuple, n::Int) = ntuple(i -> t[mod1(i - n, length(t))], length(t)) - -_indexin(v1, v2) = ntuple(n -> findfirst(isequal(v1[n]), v2), length(v1)) - -function _iscyclicpermutation(v1, v2) - length(v1) == length(v2) || return false - return iscyclicpermutation(_indexin(v1, v2)) -end - -function _findsetcircshift(p_cyclic, p_subset) - N = length(p_cyclic) - M = length(p_subset) - N == M == 0 && return 0 - i = findfirst(0:(N - 1)) do i - return issetequal(getindices(p_cyclic, ntuple(n -> mod1(n + i, N), M)), - p_subset) - end - isnothing(i) && - throw(ArgumentError("no cyclic permutation of $p_cyclic that matches $p_subset")) - return i - 1::Int -end From 7eff5e336c3694404214e610afa5aacc1c7b2e61 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:39:23 +0200 Subject: [PATCH 16/29] Add planarcontract function --- src/TensorKit.jl | 1 + src/planar/functions.jl | 59 ++++++++++++++++++++++++++++++++++ src/planar/planaroperations.jl | 9 ------ 3 files changed, 60 insertions(+), 9 deletions(-) create mode 100644 src/planar/functions.jl diff --git a/src/TensorKit.jl b/src/TensorKit.jl index b3169e4ec..8b3bb4ab5 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -202,6 +202,7 @@ include("planar/postprocessors.jl") include("planar/macros.jl") @specialize include("planar/planaroperations.jl") +include("planar/functions.jl") # deprecations: to be removed in version 1.0 or sooner include("auxiliary/deprecate.jl") diff --git a/src/planar/functions.jl b/src/planar/functions.jl new file mode 100644 index 000000000..afe9e8023 --- /dev/null +++ b/src/planar/functions.jl @@ -0,0 +1,59 @@ +# methods/simple.jl +# +# Method-based access to planar operations using simple definitions. + +# ------------------------------------------------------------------------------------------ + +""" + planarcontract(A, IA, [conjA], B, IB, [conjB], [IC], [α=1]) + planarcontract(A, pA::Index2Tuple, conjA, B, pB::Index2Tuple, conjB, pAB::Index2Tuple, α=1, [backend]) # expert mode + +Contract indices of tensor `A` with corresponding indices in tensor `B` by assigning +them identical labels in the iterables `IA` and `IB`. The indices of the resulting +tensor correspond to the indices that only appear in either `IA` or `IB` and can be +ordered by specifying the optional argument `IC`. The default is to have all open +indices of `A` followed by all open indices of `B`. Note that inner contractions of an array +should be handled first with `tensortrace`, so that every label can appear only once in `IA` +or `IB` seperately, and once (for an open index) or twice (for a contracted index) in the +union of `IA` and `IB`. + +Optionally, the symbols `conjA` and `conjB` can be used to specify that the input tensors +should be conjugated. + +See also [`tensorcontract`](@ref). +""" +function planarcontract end + +const Tuple2 = Tuple{Tuple, Tuple} + +function planarcontract(A, IA::Tuple2, conjA::Symbol, B, IB::Tuple2, conjB::Symbol, IC::Tuple2, + α::Number=One()) + @assert length(IA[1]) == numout(A) && length(IA[2]) == numin(A) "invalid IA" + @assert length(IB[1]) == numout(B) && length(IB[2]) == numin(B) "invalid IB" + pA, pB, pAB = planarcontract_indices(IA, IB, IC) + return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α) +end +# default `IC` +function planarcontract(A, IA::Tuple2, conjA::Symbol, B, IB::Tuple2, conjB::Symbol, α::Number=One()) + @assert length(IA[1]) == numout(A) && length(IA[2]) == numin(A) "invalid IA" + @assert length(IB[1]) == numout(B) && length(IB[2]) == numin(B) "invalid IB" + pA, pB, pAB = planarcontract_indices(IA, IB) + return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α) +end +# default `conjA` and `conjB` +function planarcontract(A, IA, B, IB, IC, α::Number=One()) + return planarcontract(A, IA, :N, B, IB, :N, IC, α) +end +function planarcontract(A, IA, B, IB, α::Number=One()) + return planarcontract(A, IA, :N, B, IB, :N, α) +end + +# expert mode +function planarcontract(A, pA::Index2Tuple, conjA::Symbol, + B, pB::Index2Tuple, conjB::Symbol, + pAB::Index2Tuple, α::Number=One(), + backend::Backend...) + TC = promote_contract(scalartype(A), scalartype(B), scalartype(α)) + C = tensoralloc_contract(TC, pAB, A, pA, conjA, B, pB, conjB) + return planarcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend...) +end diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 5964baa40..b932133d3 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -130,12 +130,3 @@ function planarcontract!(C::AbstractTensorMap{S}, return C end - -function planarcontract(A::AbstractTensorMap{S}, pA::Index2Tuple, - B::AbstractTensorMap{S}, pB::Index2Tuple, - pAB::Index2Tuple{N₁,N₂}, - α::Number, backend::Backend...) where {S,N₁,N₂} - TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) - C = TO.tensoralloc_contract(TC, pAB, A, pA, :N, B, pB, :N) - return planarcontract!(C, A, pA, B, pB, pAB, α, VectorInterface.Zero(), backend...) -end From ec18564f008fcfab6bc11c9951408cea939bbe27 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 14 May 2024 17:40:09 +0200 Subject: [PATCH 17/29] Update planarcontract_indices --- src/planar/indices.jl | 89 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 7 deletions(-) diff --git a/src/planar/indices.jl b/src/planar/indices.jl index 1ae93b56f..11bdfbf6a 100644 --- a/src/planar/indices.jl +++ b/src/planar/indices.jl @@ -1,12 +1,11 @@ """ - planarcontract_indices(IA, IB, IC) + planarcontract_indices(IA, IB, [IC]) Convert a set of tensor labels to a set of indices. Throws an error if this cannot be achieved in a planar manner. """ -function planarcontract_indices(IA::Tuple{NTuple{NA1},NTuple{NA2}}, - IB::Tuple{NTuple{NB1},NTuple{NB2}}, - IC::Tuple{NTuple{NC1},NTuple{NC2}}) where {NA1,NA2,NB1,NB2, - NC1,NC2} +function planarcontract_indices(IA::Tuple{Tuple,Tuple}, + IB::Tuple{Tuple,Tuple}, + IC::Tuple{Tuple,Tuple}) IA_linear = (IA[1]..., reverse(IA[2])...) IB_linear = (IB[1]..., reverse(IB[2])...) IC_linear = (IC[1]..., reverse(IC[2])...) @@ -94,8 +93,84 @@ function planarcontract_indices(IA::Tuple{NTuple{NA1},NTuple{NA2}}, IC″ = (ntuple(i -> IC′[i], length(IopenA))..., ntuple(i -> IC′[end + 1 - i], length(IopenB))...) invIC = _indexin(IC_linear, IC″) - pC = (ntuple(i -> invIC[i], NC1), - ntuple(i -> invIC[end + 1 - i], NC2)) + pC = (ntuple(i -> invIC[i], length(IC[1])), + ntuple(i -> invIC[end + 1 - i], length(IC[2]))) + + return pA, pB, pC +end +function planarcontract_indices(IA::Tuple{Tuple,Tuple}, + IB::Tuple{Tuple,Tuple}) + IA_linear = (IA[1]..., reverse(IA[2])...) + IB_linear = (IB[1]..., reverse(IB[2])...) + IAB = (IA_linear..., IB_linear...) + + Icontract = TO.tunique(TO.tsetdiff(IAB, TO.unique2(IAB))) + IopenA = TO.tsetdiff(IA_linear, Icontract) + IopenB = TO.tsetdiff(IB_linear, Icontract) + + # bring IA to the form (IopenA..., Icontract...) (as sets) + IA′ = IA_linear + ctr = 0 + while !issetequal(getindices(IA′, ntuple(identity, length(IopenA))), IopenA) + IA′ = _cyclicpermute(IA′) + ctr += 1 + ctr > length(IA′) && + throw(ArgumentError("no cyclic permutation of $IA that matches $IB")) + end + + # bring IB to the form (Icontract..., IopenB...) (as sets) + IB′ = IB_linear + ctr = 0 + while !issetequal(getindices(IB′, ntuple(i -> i + length(Icontract), length(IopenB))), + IopenB) + IB′ = _cyclicpermute(IB′) + ctr += 1 + ctr > length(IB′) && + throw(ArgumentError("no cyclic permutation of $IB that matches $IA")) + end + + # special case when IopenA is empty -> still have freedom to circshift IA + if length(IopenA) == 0 + ctr = 0 + while !isequal(IA′, reverse(getindices(IB′, ntuple(identity, length(IA′))))) + IA′ = _cyclicpermute(IA′) + ctr += 1 + ctr > length(IA′) && + throw(ArgumentError("no cyclic permutation of $IA that matches $IB")) + end + end + + # special case when IopenB is empty -> still have freedom to circshift IB + if length(IopenB) == 0 + ctr = 0 + while !isequal(IB′, + reverse(getindices(IA′, + ntuple(i -> i + length(IopenA), length(IB′))))) + IB′ = _cyclicpermute(IB′) + ctr += 1 + ctr > length(IB′) && + throw(ArgumentError("no cyclic permutation of $IB that matches $IA")) + end + end + + # special case when Icontract is empty -> still have freedom to circshift IA and IB to + # match IC + # TODO: this is not yet implemented + @assert length(Icontract) != 0 "not yet implemented" + + IA_nonlinear = (IA[1]..., IA[2]...) + pA = (_indexin(getindices(IA′, ntuple(identity, length(IopenA))), IA_nonlinear), + reverse(_indexin(getindices(IA′, + ntuple(i -> i + length(IopenA), length(Icontract))), + IA_nonlinear))) + + IB_nonlinear = (IB[1]..., IB[2]...) + pB = (_indexin(getindices(IB′, ntuple(identity, length(Icontract))), IB_nonlinear), + reverse(_indexin(getindices(IB′, + ntuple(i -> i + length(Icontract), length(IopenB))), + IB_nonlinear))) + + pC = (ntuple(identity, length(IopenA)), reverse(ntuple(i -> i + length(IopenA), length(IopenB)))) return pA, pB, pC end From a20c2ebdc61770e77deeb5030d33e11c9c7952dc Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:30:12 +0200 Subject: [PATCH 18/29] Add some more auxiliary methods for tensor indices --- src/tensors/abstracttensor.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index f69a29701..bd6e7b1da 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -75,14 +75,21 @@ codomainind(t::AbstractTensorMap) = codomainind(typeof(t)) domainind(t::AbstractTensorMap) = domainind(typeof(t)) allind(t::AbstractTensorMap) = allind(typeof(t)) +adjointtensorindex((N₁, N₂)::Tuple{Int,Int}, i) = i <= N₁ ? N₂ + i : i - N₁ function adjointtensorindex(::AbstractTensorMap{<:IndexSpace,N₁,N₂}, i) where {N₁,N₂} - return ifelse(i <= N₁, N₂ + i, i - N₁) + return adjointtensorindex((N₁, N₂), i) end +function adjointtensorindices((N₁, N₂)::Tuple{Int,Int}, indices::IndexTuple) + return map(i -> adjointtensorindex((N₁, N₂), i), indices) +end function adjointtensorindices(t::AbstractTensorMap, indices::IndexTuple) return map(i -> adjointtensorindex(t, i), indices) end +function adjointtensorindices((N₁, N₂)::Tuple{Int,Int}, p::Index2Tuple) + return adjointtensorindices((N₁, N₂), p[1]), adjointtensorindices((N₁, N₂), p[2]) +end function adjointtensorindices(t::AbstractTensorMap, p::Index2Tuple) return adjointtensorindices(t, p[1]), adjointtensorindices(t, p[2]) end From c2d7d1314654739de5e4eda712ecd575630c003c Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:30:37 +0200 Subject: [PATCH 19/29] Add planarcopy and planartrace --- src/planar/functions.jl | 42 +++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/planar/functions.jl b/src/planar/functions.jl index afe9e8023..e2d6ec65b 100644 --- a/src/planar/functions.jl +++ b/src/planar/functions.jl @@ -4,6 +4,25 @@ # ------------------------------------------------------------------------------------------ +function planarcopy(A, pA::Index2Tuple, conjA::Symbol, α::Number=One(), backend::Backend...) + TC = TO.promote_add(scalartype(A), scalartype(α)) + C = tensoralloc_add(TC, pA, A, conjA) + return planaradd!(C, A, pA, conjA, α, Zero(), backend...) +end + +# ------------------------------------------------------------------------------------------ + + +function planartrace(A, pA::Index2Tuple, qA::Index2Tuple, conjA::Symbol, α::Number=One(), backend::Backend...) + TC = TO.promote_contract(scalartype(A), scalartype(α)) + C = tensoralloc_add(TC, pA, A, conjA) + return planartrace!(C, A, pA, qA, conjA, α, Zero(), backend...) +end + +# ------------------------------------------------------------------------------------------ + + + """ planarcontract(A, IA, [conjA], B, IB, [conjB], [IC], [α=1]) planarcontract(A, pA::Index2Tuple, conjA, B, pB::Index2Tuple, conjB, pAB::Index2Tuple, α=1, [backend]) # expert mode @@ -24,20 +43,19 @@ See also [`tensorcontract`](@ref). """ function planarcontract end -const Tuple2 = Tuple{Tuple, Tuple} - -function planarcontract(A, IA::Tuple2, conjA::Symbol, B, IB::Tuple2, conjB::Symbol, IC::Tuple2, +function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, conjB::Symbol, IC::TensorLabels, α::Number=One()) - @assert length(IA[1]) == numout(A) && length(IA[2]) == numin(A) "invalid IA" - @assert length(IB[1]) == numout(B) && length(IB[2]) == numin(B) "invalid IB" - pA, pB, pAB = planarcontract_indices(IA, IB, IC) + ia = canonicalize_labels(A, IA) + ib = canonicalize_labels(B, IB) + ic = canonicalize_labels(IC) + pA, pB, pAB = planarcontract_indices(ia, ib, ic) return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α) end # default `IC` -function planarcontract(A, IA::Tuple2, conjA::Symbol, B, IB::Tuple2, conjB::Symbol, α::Number=One()) - @assert length(IA[1]) == numout(A) && length(IA[2]) == numin(A) "invalid IA" - @assert length(IB[1]) == numout(B) && length(IB[2]) == numin(B) "invalid IB" - pA, pB, pAB = planarcontract_indices(IA, IB) +function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, conjB::Symbol, α::Number=One()) + ia = canonicalize_labels(A, IA) + ib = canonicalize_labels(B, IB) + pA, pB, pAB = planarcontract_indices(ia, ib) return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α) end # default `conjA` and `conjB` @@ -53,7 +71,7 @@ function planarcontract(A, pA::Index2Tuple, conjA::Symbol, B, pB::Index2Tuple, conjB::Symbol, pAB::Index2Tuple, α::Number=One(), backend::Backend...) - TC = promote_contract(scalartype(A), scalartype(B), scalartype(α)) - C = tensoralloc_contract(TC, pAB, A, pA, conjA, B, pB, conjB) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + C = TO.tensoralloc_contract(TC, pAB, A, pA, conjA, B, pB, conjB) return planarcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend...) end From 0f8544eed9465c9e8b6456144800bc0595966207 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:30:49 +0200 Subject: [PATCH 20/29] more planar index stuff --- src/planar/indices.jl | 93 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 87 insertions(+), 6 deletions(-) diff --git a/src/planar/indices.jl b/src/planar/indices.jl index 11bdfbf6a..51a06aaee 100644 --- a/src/planar/indices.jl +++ b/src/planar/indices.jl @@ -1,8 +1,90 @@ +const TensorLabels = Union{Tuple,Vector} + +function canonicalize_labels(A::AbstractTensorMap, IA::TensorLabels) + numind(A) == length(IA) || + throw(ArgumentError("invalid labels for tensor: $IA for ($(numout(A)), $(numin(A)))")) + return (ntuple(i -> IA[i], numout(A)), ntuple(i -> IA[numout(A) + i], numin(A))) +end +canonicalize_labels(IA::TensorLabels) = (tuple(IA...), ()) + +function _isplanar(inds::Index2Tuple, p::Index2Tuple) + return iscyclicpermutation((inds[1]..., inds[2]...), + (p[1]..., reverse(p[2])...)) +end + +function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA, IC::Tuple{Tuple,Tuple}) + IA′ = conjA == :C ? reverse(IA) : IA + + IA_linear = (IA′[1]..., (IA′[2])...) + IC_linear = (IC[1]..., (IC[2])...) + + p, q1, q2 = TO.trace_indices(IA_linear, IC_linear) + + if conjA == :C + p′ = adjointtensorindices((length(IC[1]), length(IC[2])), p) + q′ = adjointtensorindices((length(IA[2]), length(IA[1])), (q1, q2)) + else + p′ = p + q′ = (q1, q2) + end + + return p′, q′ +end +function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA) + IA′ = conjA == :C ? reverse(IA) : IA + + IA_linear = (IA′[1]..., (IA′[2])...) + IC_linear = tuple(TO.unique2(IA_linear)...) + p, q1, q2 = TO.trace_indices(IA_linear, IC_linear) + + if conjA == :C + p′ = adjointtensorindices((length(p[1]), length(p[2])), p) + q′ = adjointtensorindices((length(IA[2]), length(IA[1])), (q1, q2)) + else + p′ = p + q′ = (q1, q2) + end + + return p′, q′ +end + + """ planarcontract_indices(IA, IB, [IC]) -Convert a set of tensor labels to a set of indices. Throws an error if this cannot be achieved in a planar manner. +Convert a set of tensor labels to a set of indices. Throws an error if this cannot be +achieved in a planar manner. """ +function planarcontract_indices(IA::Tuple{Tuple,Tuple}, conjA::Symbol, + IB::Tuple{Tuple,Tuple}, conjB::Symbol, + IC::Tuple{Tuple,Tuple}) + + IA′ = conjA == :C ? reverse(IA) : IA + IB′ = conjB == :C ? reverse(IB) : IB + + pA, pB, pAB = planarcontract_indices(IA′, IB′, IC) + + # map indices back to original tensor + pA′ = conjA == :C ? adjointtensorindices((length(IA[2]), length(IA[1])), pA) : pA + pB′ = conjB == :C ? adjointtensorindices((length(IB[2]), length(IB[1])), pB) : pB + + return pA′, pB′, pAB +end +function planarcontract_indices(IA::Tuple{Tuple,Tuple}, conjA::Symbol, + IB::Tuple{Tuple,Tuple}, conjB::Symbol) + # map indices to indices of adjoint tensor + IA′ = conjA == :C ? reverse(IA) : IA + IB′ = conjB == :C ? reverse(IB) : IB + + pA, pB, pAB = planarcontract_indices(IA′, IB′) + + # map indices back to original tensor + pA′ = conjA == :C ? adjointtensorindices((length(IA[2]), length(IA[1])), pA) : pA + pB′ = conjB == :C ? adjointtensorindices((length(IB[2]), length(IB[1])), pB) : pB + + return pA′, pB′, pAB +end + function planarcontract_indices(IA::Tuple{Tuple,Tuple}, IB::Tuple{Tuple,Tuple}, IC::Tuple{Tuple,Tuple}) @@ -64,7 +146,6 @@ function planarcontract_indices(IA::Tuple{Tuple,Tuple}, # bring IC to the form (IopenA..., IopenB...) (as sets) IC′ = IC_linear - IopenA ctr = 0 while !issetequal(getindices(IC′, ntuple(identity, length(IopenA))), IopenA) IC′ = _cyclicpermute(IC′) @@ -104,7 +185,7 @@ function planarcontract_indices(IA::Tuple{Tuple,Tuple}, IB_linear = (IB[1]..., reverse(IB[2])...) IAB = (IA_linear..., IB_linear...) - Icontract = TO.tunique(TO.tsetdiff(IAB, TO.unique2(IAB))) + Icontract = TO.tunique(TO.tsetdiff(IAB, tuple(TO.unique2(IAB)...))) IopenA = TO.tsetdiff(IA_linear, Icontract) IopenB = TO.tsetdiff(IB_linear, Icontract) @@ -170,7 +251,7 @@ function planarcontract_indices(IA::Tuple{Tuple,Tuple}, ntuple(i -> i + length(Icontract), length(IopenB))), IB_nonlinear))) - pC = (ntuple(identity, length(IopenA)), reverse(ntuple(i -> i + length(IopenA), length(IopenB)))) + pC = (ntuple(identity, length(IopenA)), ntuple(i -> i + length(IopenA), length(IopenB))) return pA, pB, pC end @@ -225,7 +306,7 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) indA_lin = _circshift(indA_lin, -iA) end pc = ntuple(identity, NA₂) - @assert all(getindices(indA_lin, ntuple(identity, NA₁)) .== pA′[1]) "sanity check" + @assert all(getindices(indA_lin, ntuple(identity, NA₁)) .== pA′[1]) "sanity check: $indA $pA" pA′ = (pA′[1], reverse(getindices(indA_lin, pc .+ NA₁))) # cycle indB to be of the form (cindB..., reverse(oindB)...) @@ -234,7 +315,7 @@ function reorder_planar_indices(indA, pA, indB, pB, pAB) iB = findfirst(==(first(pB′[2])), indB_lin) indB_lin = _circshift(indB_lin, -iB) end - @assert all(getindices(indB_lin, ntuple(identity, NB₂) .+ NB₁) .== reverse(pB′[2])) "sanity check" + @assert all(getindices(indB_lin, ntuple(identity, NB₂) .+ NB₁) .== reverse(pB′[2])) "sanity check: $indB $pB" pB′ = (getindices(indB_lin, pc), pB′[2]) # if uncontracted indices are empty, we can still make cyclic adjustments From b50ad10e711b68c61003342aad1e776e113fdaf6 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:31:03 +0200 Subject: [PATCH 21/29] clean up planaroperations --- src/planar/planaroperations.jl | 134 +++++++++++++++++++-------------- src/planar/postprocessors.jl | 6 +- 2 files changed, 81 insertions(+), 59 deletions(-) diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index b932133d3..8df67b539 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -1,47 +1,39 @@ -# planar versions of tensor operations add!, trace! and contract! -function planaradd!(C::AbstractTensorMap{S,N₁,N₂}, - A::AbstractTensorMap{S}, - p::Index2Tuple{N₁,N₂}, - α::Number, - β::Number, - backend::Backend...) where {S,N₁,N₂} - return add_transpose!(C, A, p, α, β, backend...) -end +# ---------- +# CONJ FLAGS +# ---------- -function planartrace!(C::AbstractTensorMap{S,N₁,N₂}, - A::AbstractTensorMap{S}, - p::Index2Tuple{N₁,N₂}, - q::Index2Tuple{N₃,N₃}, - α::Number, - β::Number, - backend::Backend...) where {S,N₁,N₂,N₃} - if BraidingStyle(sectortype(S)) == Bosonic() - return trace_permute!(C, A, p, q, α, β, backend...) +function planaradd!(C::AbstractTensorMap{S}, + A::AbstractTensorMap{S}, conjA::Symbol, + p::Index2Tuple, + α::Number, β::Number, backend::Backend...) where {S} + if conjA == :N + A′ = A + p′ = _canonicalize(p, C) + elseif conjA == :C + A′ = adjoint(A) + p′ = adjointtensorindices(A, _canonicalize(p, C)) + else + throw(ArgumentError("unknown conjugation flag $conjA")) end + return add_transpose!(C, A′, p′, α, β, backend...) +end - @boundscheck begin - all(i -> space(A, p[1][i]) == space(C, i), 1:N₁) || - throw(SpaceMismatch("trace: A = $(codomain(A))←$(domain(A)), - C = $(codomain(C))←$(domain(C)), p1 = $(p1), p2 = $(p2)")) - all(i -> space(A, p[2][i]) == space(C, N₁ + i), 1:N₂) || - throw(SpaceMismatch("trace: A = $(codomain(A))←$(domain(A)), - C = $(codomain(C))←$(domain(C)), p1 = $(p1), p2 = $(p2)")) - all(i -> space(A, q[1][i]) == dual(space(A, q[2][i])), 1:N₃) || - throw(SpaceMismatch("trace: A = $(codomain(A))←$(domain(A)), - q1 = $(q1), q2 = $(q2)")) +function planartrace!(C::AbstractTensorMap{S}, p::Index2Tuple, + A::AbstractTensorMap{S}, q::Index2Tuple, conjA::Symbol, + α::Number, β::Number, backend::Backend...) where {S} + if conjA == :N + A′ = A + p′ = _canonicalize(p, C) + q′ = q + elseif conjA == :C + A′ = A' + p′ = adjointtensorindices(A, _canonicalize(p, C)) + q′ = adjointtensorindices(A, q) + else + throw(ArgumentError("unknown conjugation flag $conjA")) end - if iszero(β) - fill!(C, β) - elseif !isone(β) - rmul!(C, β) - end - for (f₁, f₂) in fusiontrees(A) - for ((f₁′, f₂′), coeff) in planar_trace(f₁, f₂, p..., q...) - TO.tensortrace!(C[f₁′, f₂′], p, A[f₁, f₂], q, :N, α * coeff, true, backend...) - end - end - return C + return trace_transpose!(C, A′, p′, q′, α, β, backend...) end function planarcontract!(C::AbstractTensorMap, @@ -49,7 +41,6 @@ function planarcontract!(C::AbstractTensorMap, B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol, pAB::Index2Tuple, α::Number, β::Number, backend::Backend...) - # get rid of conj arguments by going to adjoint tensormaps if conjA == :N A′ = A pA′ = pA @@ -69,31 +60,62 @@ function planarcontract!(C::AbstractTensorMap, throw(ArgumentError("unknown conjugation flag $conjB")) end - return planarcontract!(C, A, pA′, B, pB′, pAB, α, β, backend...) + return _planarcontract!(C, A′, pA′, B′, pB′, pAB, α, β, backend...) end -function _isplanar(inds::Index2Tuple, p::Index2Tuple) - return iscyclicpermutation((inds[1]..., inds[2]...), - (p[1]..., reverse(p[2])...)) +# --------------- +# IMPLEMENTATIONS +# --------------- + +function trace_transpose!(tdst::AbstractTensorMap{S,N₁,N₂}, + tsrc::AbstractTensorMap{S}, + (p₁, p₂)::Index2Tuple{N₁,N₂}, (q₁, q₂)::Index2Tuple{N₃,N₃}, + α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} + @boundscheck begin + space(tdst) == permute(space(tsrc), (p₁, p₂)) || + throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), + tdst = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) + all(i -> space(tsrc, q₁[i]) == dual(space(tsrc, q₂[i])), 1:N₃) || + throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))←$(domain(tsrc)), + q₁ = $(q₁), q₂ = $(q₂)")) + # TODO: check planarity? + end + + # TODO: not sure if this is worth it + if BraidingStyle(sectortype(S)) == Bosonic() + return @inbounds trace_permute!(tdst, tsrc, p, q, α, β, backend...) + end + + scale!(tdst, β) + β′ = One() + + for (f₁, f₂) in fusiontrees(tsrc) + @inbounds A = tsrc[f₁, f₂] + for ((f₁′, f₂′), coeff) in planar_trace(f₁, f₂, p₁, p₂, q₁, q₂) + @inbounds C = tdst[f₁′, f₂′] + TO.tensortrace!(C, (p₁, p₂), A, (q₁, q₂), :N, α * coeff, β′, backend...) + end + end + + return tdst end -function planarcontract!(C::AbstractTensorMap{S}, - A::AbstractTensorMap{S}, - pA::Index2Tuple{N₁,N₃}, - B::AbstractTensorMap{S}, - pB::Index2Tuple{N₃,N₂}, - pAB::Index2Tuple, - α::Number, - β::Number, - backend::Backend...) where {S,N₁,N₂,N₃} +# TODO: reuse the same memcost checks as in `contract!` +function _planarcontract!(C::AbstractTensorMap{S}, + A::AbstractTensorMap{S}, pA::Index2Tuple{N₁,N₃}, + B::AbstractTensorMap{S}, pB::Index2Tuple{N₃,N₂}, + pAB::Index2Tuple, + α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) - indC = (codomainind(C), reverse(domainind(C))) + indAB = (ntuple(identity, N₁), reverse(ntuple(i -> i + N₁, N₂))) + + # TODO: avoid this step once @planar is reimplemented pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) @assert _isplanar(indA, pA′) "not a planar contraction (indA = $indA, pA′ = $pA′)" - @assert _isplanar(indB, pB′) "not a planar contraction (pB′ = $pB′)" - @assert _isplanar(indC, pAB′) "not a planar contraction (pAB′ = $pAB′)" + @assert _isplanar(indB, pB′) "not a planar contraction (indB = $indB, pB′ = $pB′)" + @assert _isplanar(indAB, pAB′) "not a planar contraction (indAB = $indAB, pAB′ = $pAB′)" if BraidingStyle(sectortype(spacetype(C))) == Bosonic() return contract!(C, A, pA′, B, pB′, pAB′, α, β, backend...) diff --git a/src/planar/postprocessors.jl b/src/planar/postprocessors.jl index 7e9bd8e51..c191f3005 100644 --- a/src/planar/postprocessors.jl +++ b/src/planar/postprocessors.jl @@ -51,10 +51,10 @@ end # planar operations, immediately inserting them with `GlobalRef`. # NOTE: work around a somewhat unfortunate interface choice in TensorOperations, which we will correct in the future. -_planaradd!(C, p, A, α, β, backend...) = planaradd!(C, A, p, α, β, backend...) -_planartrace!(C, p, A, q, α, β, backend...) = planartrace!(C, A, p, q, α, β, backend...) +_planaradd!(C, p, A, α, β, backend...) = planaradd!(C, A, :N, p, α, β, backend...) +_planartrace!(C, p, A, q, α, β, backend...) = planartrace!(C, A, :N, p, q, α, β, backend...) function _planarcontract!(C, pAB, A, pA, B, pB, α, β, backend...) - return planarcontract!(C, A, pA, B, pB, pAB, α, β, backend...) + return planarcontract!(C, A, pA, :N, B, pB, :N, pAB, α, β, backend...) end # TODO: replace _planarmethod with planarmethod in everything below const _PLANAR_OPERATIONS = (:_planaradd!, :_planartrace!, :_planarcontract!) From 42db3f0163c7106d415ea3d2d94220e95116b8b2 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:31:14 +0200 Subject: [PATCH 22/29] cherrypick some fix from master --- src/spaces/homspace.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index abcfb5efe..c1c4e6aff 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -115,3 +115,22 @@ function dim(W::HomSpace) end return d end + +# Operations on HomSpaces +# ----------------------- +function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁,N₂}) where {S,N₁,N₂} + cod = ProductSpace{S,N₁}(map(n -> W[n], p₁)) + dom = ProductSpace{S,N₂}(map(n -> dual(W[n]), p₂)) + return cod ← dom +end + +""" + compose(W::HomSpace, V::HomSpace) + +Obtain the HomSpace that is obtained from composing the morphisms in `W` and `V`. For this +to be possible, the domain of `W` must match the codomain of `V`. +""" +function compose(W::HomSpace{S}, V::HomSpace{S}) where {S} + domain(W) == codomain(V) || throw(SpaceMismatch("$(domain(W)) ≠ $(codomain(V))")) + return HomSpace(codomain(W), domain(V)) +end From a5106f40bc468472d13841b2144f071e9b053419 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:31:26 +0200 Subject: [PATCH 23/29] Add plancon --- src/TensorKit.jl | 3 +- src/planar/plancon.jl | 139 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 src/planar/plancon.jl diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 8b3bb4ab5..092ef50f4 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -78,7 +78,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ SVD, SDD, Polar # tensor operations -export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor +export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor, plancon export scalar, add!, contract! # truncation schemes @@ -203,6 +203,7 @@ include("planar/macros.jl") @specialize include("planar/planaroperations.jl") include("planar/functions.jl") +include("planar/plancon.jl") # deprecations: to be removed in version 1.0 or sooner include("auxiliary/deprecate.jl") diff --git a/src/planar/plancon.jl b/src/planar/plancon.jl new file mode 100644 index 000000000..e5fca6d7d --- /dev/null +++ b/src/planar/plancon.jl @@ -0,0 +1,139 @@ +""" + plancon(tensorlist, indexlist, [conjlist]; order = ..., output = ...) + +Contract the tensors in `tensorlist` (of type `Vector` or `Tuple`) according to the network +as specified by `indexlist`. Here, `indexlist` is a list (i.e. a `Vector` or `Tuple`) with +the same length as `tensorlist` whose entries are themselves lists (preferably +`Vector{Int}`) where every integer entry provides a label for corresponding index/dimension +of the corresponding tensor in `tensorlist`. Positive integers are used to label indices +that need to be contracted, and such thus appear in two different entries within +`indexlist`, whereas negative integers are used to label indices of the output tensor, and +should appear only once. + +Optional arguments in another list with the same length, `conjlist`, whose entries are of +type `Bool` and indicate whether the corresponding tensor object should be conjugated +(`true`) or not (`false`). The default is `false` for all entries. + +By default, contractions are performed in the order such that the indices being contracted +over are labelled by increasing integers, i.e. first the contraction corresponding to label +`1` is performed. The output tensor had an index order corresponding to decreasing +(negative, so increasing in absolute value) index labels. The keyword arguments `order` and +`output` allow to change these defaults. + +See also the macro version [`@planar`](@ref). +""" +function plancon(tensors, network, + conjlist=fill(false, length(tensors)); + order=nothing, output=nothing) + length(tensors) == length(network) == length(conjlist) || + throw(ArgumentError("number of tensors and of index lists should be the same")) + # TO.isnconstyle(network) || throw(ArgumentError("invalid NCON network: $network")) + output′ = planconoutput(network, output) + + if length(tensors) == 1 + error("not implemented") + if length(output′) == length(network[1]) + return tensorcopy(output′, tensors[1], network[1], conjlist[1] ? :C : :N) + else + return tensortrace(output′, tensors[1], network[1], conjlist[1] ? :C : :N) + end + end + + (tensors, network) = TO.resolve_traces(tensors, network) + tree = order === nothing ? plancontree(network) : TO.indexordertree(network, order) + return planarcontracttree(tensors, network, conjlist, tree, output′) +end + +# single tensor case +function planarcontracttree(tensors, network, conjlist, tree::Int, output) + # extract data + A = tensors[tree] + IA = canonicalize_labels(A, network[tree]) + conjA = conjlist[tree] ? :C : :N + + pA, qA = planartrace_indices(IA, conjA, output) + + if isempty(qA[1]) # no traced indices + return planarcopy(A, pA, conjA) + C = tensoralloc_add(scalartype(A), pA, A, conjA) + return planaradd!(C, A, pA, conjA, One(), Zero()) + else + return planartrace(A, pA, qA, conjA) + end +end + +# recursive case +function planarcontracttree(tensors, network, conjlist, tree::Int) + # extract data + A = tensors[tree] + IA = canonicalize_labels(A, network[tree]) + conjA = conjlist[tree] ? :C : :N + + pA, qA = planartrace_indices(IA, conjA) + + if isempty(qA[1]) # no traced indices + C = A + IC = IA + conjC = conjA + else + C = planartrace(A, pA, qA, conjA) + IC = (TupleTools.getindices(linearize(IA), pA[1]), + TupleTools.getindices(linearize(IA), pA[2])) + conjC = :N + end + + return C, IC, conjC +end + + +function planarcontracttree(tensors, network, conjlist, tree) + @assert !(tree isa Int) "single-node tree should already have been handled" + A, IA, CA = planarcontracttree(tensors, network, conjlist, tree[1]) + B, IB, CB = planarcontracttree(tensors, network, conjlist, tree[2]) + pA, pB, pAB = planarcontract_indices(IA, CA, IB, CB) + + C = planarcontract(A, pA, CA, B, pB, CB, pAB) + + # deduce labels of C + IAB = (TupleTools.getindices(linearize(IA), pA[1])..., + TupleTools.getindices(linearize(IB), pB[2])...) + IC = (TupleTools.getindices(IAB, pAB[1]), TupleTools.getindices(IAB, pAB[2])) + + return C, IC, :N +end +# special case for last step -- dispatch on output argument +function planarcontracttree(tensors, network, conjlist, tree, output) + @assert !(tree isa Int) "single-node tree should already have been handled" + A, IA, CA = planarcontracttree(tensors, network, conjlist, tree[1]) + B, IB, CB = planarcontracttree(tensors, network, conjlist, tree[2]) + pA, pB, pAB = planarcontract_indices(IA, CA, IB, CB, output) + + return planarcontract(A, pA, CA, B, pB, CB, pAB) +end + +function planconoutput(network, output::Union{Nothing,Tuple{Tuple,Tuple}}) + outputindices = Vector{Int}() + for n in network + for k in n + if k < 0 + push!(outputindices, k) + end + end + end + isnothing(output) && return (tuple(sort(outputindices; rev=true)...), ()) + + issetequal(TO.linearize(output), outputindices) || + throw(ArgumentError("invalid NCON network: $network -> $output")) + return output +end + +function plancontree(network) + contractionindices = Vector{Vector{Int}}(undef, length(network)) + for k in 1:length(network) + indices = network[k] + # trace indices have already been removed, remove open indices by filtering on positive values + contractionindices[k] = Base.filter(>(0), indices) + end + partialtrees = collect(Any, 1:length(network)) + return TO._ncontree!(partialtrees, contractionindices) +end \ No newline at end of file From d6b625b0a462f5c52c652b1575b60c579730c08a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 15 May 2024 19:31:45 +0200 Subject: [PATCH 24/29] update planar tests --- test/planar.jl | 68 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/test/planar.jl b/test/planar.jl index 44cbd4961..24d7a2086 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -1,6 +1,6 @@ using TensorKit, TensorOperations, Test using TensorKit: planaradd!, planartrace!, planarcontract! -using TensorKit: PlanarTrivial, ℙ +using TensorKit: PlanarTrivial, ℙ, BraidingTensor """ force_planar(obj) @@ -57,6 +57,14 @@ using TensorKit: planarcontract_indices @test pA == ((4, 1, 2), (5, 3)) @test pB == ((2, 1), (3, 4)) @test pC == ((1, 2, 3), (4, 5)) + + IA = ((-1, 7), (6,)) + IB = ((-4, -3, 6), (-2, 7)) + IC = ((-1, -2), (-3, -4)) + pA, pB, pAB = planarcontract_indices(IA, IB, IC) + @test pA == ((1,), (3, 2)) + @test pB == ((3, 5), (2, 1, 4)) + @test pAB == ((1, 4), (2, 3)) end using TensorKit: reorder_planar_indices @@ -270,7 +278,7 @@ end p = ((4, 3), (5, 2, 1)) @test force_planar(tensoradd!(C, p, A, :N, true, true)) ≈ - planaradd!(C′, A′, p, true, true) + planaradd!(C′, A′, :N, p, true, true) end @testset "planartrace" begin @@ -282,7 +290,7 @@ end q = ((1,), (3,)) @test force_planar(tensortrace!(C, p, A, q, :N, true, true)) ≈ - planartrace!(C′, A′, p, q, true, true) + planartrace!(C′, p, A′, q, :N, true, true) end @testset "planarcontract" begin @@ -299,10 +307,27 @@ end pAB = ((3, 2, 1), (4, 5)) @test force_planar(tensorcontract!(C, pAB, A, pA, :N, B, pB, :N, true, true)) ≈ - planarcontract!(C′, A′, pA, B′, pB, pAB, true, true) + planarcontract!(C′, A′, pA, :N, B′, pB, :N, pAB, true, true) end end +@testset "plancon" verbose = true begin + V = CartesianSpace(2) + A = TensorMap(rand, Float64, V ← V) + + AA = plancon([A, A], [[-1, 1], [1, -2]]) + AA′ = transpose(A * A, ((1, 2), ())) + @test AA ≈ AA′ + + AAA = plancon([A, A, A], [[-1, 1], [1, 2], [2, -2]]) + AAA′ = transpose(A * A * A, ((1, 2), ())) + @test AAA ≈ AAA′ + + AAA = plancon([A, A, A], [[-1, 1], [1, 2], [2, -2]]; output=((-1,), (-2,))) + AAA′ = A * A * A + @test AAA ≈ AAA′ +end + @testset "@planar" verbose = true begin T = ComplexF64 @@ -347,7 +372,10 @@ end @tensor y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * O[2 -2; 3 5] * GR[4 5; -3] @planar y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * O′[2 -2; 3 5] * GR′[4 5; -3] + y″ = plancon([GL′, x′, O′, GR′], [[-1, 2, 1], [1, 3, 4], [2, -2, 3, 5], [4, 5, -3]]; + output=((-1, -2), (-3,))) @test force_planar(y) ≈ y′ + @test force_planar(y) ≈ y″ # ∂AC2 # ------- @@ -358,7 +386,11 @@ end GR[1 2; -3] @planar y2′[-1 -2; -3 -4] := GL′[-1 7; 6] * x2′[6 5; 1 3] * O′[7 -2; 5 4] * O′[4 -4; 3 2] * GR′[1 2; -3] + y2″ = plancon([GL′, x2′, O′, O′, GR′], + [[-1, 7, 6], [6, 5, 1, 3], [7, -2, 5, 4], + [4, -4, 3, 2], [1, 2, -3]]; output=((-1, -2), (-3, -4))) @test force_planar(y2) ≈ y2′ + @test force_planar(y2) ≈ y2″ # transfer matrix # ---------------- @@ -366,13 +398,22 @@ end v′ = force_planar(v) @tensor ρ[-1; -2] := x[-1 2; 1] * conj(x[-2 2; 3]) * v[1; 3] @planar ρ′[-1; -2] := x′[-1 2; 1] * conj(x′[-2 2; 3]) * v′[1; 3] + + ρ‴ = ncon([x, x, v], [[-1, 2, 1], [-2, 2, 3], [1, 3]], [false, true, false]) + ρ″ = plancon([x′, x′, v′], [[-1, 2, 1], [-2, 2, 3], [1, 3]], [false, true, false]; output=((-1,), (-2,))) @test force_planar(ρ) ≈ ρ′ - + @test force_planar(ρ) ≈ ρ″ + @tensor ρ2[-1 -2; -3] := GL[1 -2; 3] * x[3 2; -3] * conj(x[1 2; -1]) @plansor ρ3[-1 -2; -3] := GL[1 2; 4] * x[4 5; -3] * τ[2 3; 5 -2] * conj(x[1 3; -1]) @planar ρ2′[-1 -2; -3] := GL′[1 2; 4] * x′[4 5; -3] * τ[2 3; 5 -2] * conj(x′[1 3; -1]) + τtensor = BraidingTensor(space(x′, 2), space(GL′, 2)') + ρ2″ = plancon([GL′, x′, τtensor, x′], [[1, 2, 4], [4, 5, -3], [2, 3, 5, -2], [1, 3, -1]], + [false, false, false, true]; output=((-1, -2), (-3,))) + @test force_planar(ρ2) ≈ ρ2′ + @test force_planar(ρ2) ≈ ρ2″ @test ρ2 ≈ ρ3 # Periodic boundary conditions @@ -389,8 +430,16 @@ end @planar O_periodic′[-1 -2; -3 -4] := O′[1 2; -3 6] * f1′[-1; 1 3 5] * conj(f2′[-4; 6 7 8]) * τ[2 3; 7 4] * τ[4 5; 8 -2] + τtensor1 = BraidingTensor(space(f2′, 3)', space(O′, 2)') + τtensor2 = BraidingTensor(space(f2′, 4)', space(O′, 2)') + O_periodic″ = plancon([O′, f1′, f2′, τtensor1, τtensor2], + [[1, 2, -3, 6], [-1, 1, 3, 5], [-4, 6, 7, 8], [2, 3, 7, 4], + [4, 5, 8, -2]], + [false, false, true, false, false]; + output=((-1, -2), (-3, -4))) @test O_periodic1 ≈ O_periodic2 @test force_planar(O_periodic1) ≈ O_periodic′ + @test force_planar(O_periodic1) ≈ O_periodic″ end @testset "MERA networks" begin @@ -420,7 +469,16 @@ end ((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) * w′[16 15; 18]) * conj(w′[16 17; 21])) end + C″ = tensorscalar(plancon([h′, u′, u′, u′, w′, u′, w′, w′, w′, ρ′, w′, w′], + [[9, 3, 4, 5, 1, 2], [1, 2, 7, 12], [3, 4, 11, 13], + [8, 5, 15, 6], [6, 7, 19], [8, 9, 17, 10], [10, 11, 22], + [12, 14, 20], [13, 14, 23], [18, 19, 20, 21, 22, 23], + [16, 15, 18], [16, 17, 21]], + [false, false, true, false, false, true, true, false, + true, false, false, true])) + @test C ≈ C′ + @test C ≈ C″ end @testset "Issue 93" begin From 8bd8a15529e7c7586438265fd8e15cd953217a3f Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 17 May 2024 10:04:25 +0200 Subject: [PATCH 25/29] Updates to make planar AD work --- ext/TensorKitChainRulesCoreExt.jl | 93 ++++++++++++++++++++----------- src/planar/functions.jl | 12 ++-- src/planar/indices.jl | 22 ++++---- src/planar/planaroperations.jl | 18 +++--- src/planar/plancon.jl | 9 ++- src/planar/postprocessors.jl | 2 +- test/ad.jl | 20 +++---- test/planar.jl | 18 +++--- 8 files changed, 109 insertions(+), 85 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index c455ad308..40ff4c69e 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -637,11 +637,12 @@ end # Planar rrules # -------------- -function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorMap{S,N₁,N₂}, - A::AbstractTensorMap{S}, p::Index2Tuple{N₁,N₂}, +function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, - backend::Backend...) where {S,N₁,N₂} - C′ = planaradd!(copy(C), A, p, α, β, backend...) + backend::Backend...) + C′ = planaradd!(copy(C), A, pA, conjA, α, β, backend...) projectA = ProjectTo(A) projectC = ProjectTo(C) @@ -653,42 +654,43 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!), C::AbstractTensorM dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk begin - ip = _canonicalize(invperm(linearize(p)), A) + ip = _canonicalize(invperm(linearize(pA)), A) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) - _dA = planaradd!(_dA, ΔC, ip, conj(α), Zero(), backend...) + _dA = planaradd!(_dA, ΔC, ip, conjA, conjA == :N ? conj(α) : α, Zero(), + backend...) return projectA(_dA) end dα = @thunk begin - p′ = TensorKit.adjointtensorindices(A, p) - _dα = tensorscalar(planarcontract(A', ((), linearize(p′)), - ΔC, (trivtuple(p), ()), + _dα = tensorscalar(planarcontract(A, ((), linearize(pA)), _conj(conjA), + ΔC, (trivtuple(pA), ()), :N, ((), ()), One(), backend...)) return projectα(_dα) end dβ = @thunk begin - p′ = TensorKit.adjointtensorindices(C, trivtuple(p)) - _dβ = tensorscalar(planarcontract(C', ((), p′), - ΔC, (trivtuple(p), ()), + _dβ = tensorscalar(planarcontract(C, + ((), trivtuple(TensorOperations.numind(pA))), + :C, + ΔC, (trivtuple(pA), ()), :N, ((), ()), One(), backend...)) return projectβ(_dβ) end dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, dA, NoTangent(), dα, dβ, dbackend... + return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dbackend... end return C′, planaradd_pullback end function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), - C::AbstractTensorMap{S,N₁,N₂}, - A::AbstractTensorMap{S}, pA::Index2Tuple, - B::AbstractTensorMap{S}, pB::Index2Tuple, - pAB::Index2Tuple{N₁,N₂}, - α::Number, β::Number, backend::Backend...) where {S,N₁,N₂} - indA = (codomainind(A), reverse(domainind(A))) - indB = (codomainind(B), reverse(domainind(B))) - pA, pB, pAB = TensorKit.reorder_planar_indices(indA, pA, indB, pB, pAB) - C′ = planarcontract!(copy(C), A, pA, B, pB, pAB, α, β, backend...) + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol, + B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol, + pAB::Index2Tuple, + α::Number, β::Number, backend::Backend...) + # indA = (codomainind(A), reverse(domainind(A))) + # indB = (codomainind(B), reverse(domainind(B))) + # pA, pB, pAB = TensorKit.reorder_planar_indices(indA, pA, indB, pB, pAB) + C′ = planarcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, backend...) projectA = ProjectTo(A) projectB = ProjectTo(B) @@ -704,26 +706,34 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk begin ipA = _canonicalize(invperm(linearize(pA)), A) + conjΔC = conjA == :C ? :C : :N + conjB′ = conjA == :C ? conjB : _conj(conjB) _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) - pB′ = TensorKit.adjointtensorindices(B, reverse(pB)) - _dA = planarcontract!(_dA, ΔC, pΔC, adjoint(B), pB′, ipA, - conj(α), Zero(), backend...) + _dA = planarcontract!(_dA, ΔC, pΔC, conjΔC, B, reverse(pB), conjB′, ipA, + conjA == :C ? α : conj(α), Zero(), backend...) return projectA(_dA) end dB = @thunk begin ipB = _canonicalize((invperm(linearize(pB)), ()), B) + conjΔC = conjB == :C ? :C : :N + conjA′ = conjB == :C ? conjA : _conj(conjA) _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) - pA′ = TensorKit.adjointtensorindices(A, reverse(pA)) - _dB = planarcontract!(_dB, adjoint(A), pA′, ΔC, pΔC, ipB, - conj(α), Zero(), backend...) + _dB = planarcontract!(_dB, + A, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + ipB, conjB == :C ? α : conj(α), Zero(), backend...) return projectB(_dB) end dα = @thunk begin - AB = planarcontract!(similar(C), A, pA, B, pB, pAB, One(), Zero(), backend...) - p′ = TensorKit.adjointtensorindices(AB, trivtuple(pAB)) - _dα = tensorscalar(planarcontract(AB', ((), p′), - ΔC, (trivtuple(pAB), ()), ((), ()), - One(), backend...)) + _dα = tensorscalar(planarcontract(planarcontract(A, pA, conjA, + B, pB, conjB, + pAB, One(), backend...), + ((), trivtuple(TensorOperations.numind(pAB))), + :C, + ΔC, + (trivtuple(TensorOperations.numind(pAB)), ()), + :N, + ((), ()), One(), backend...)) return projectα(_dα) end dβ = @thunk begin @@ -734,13 +744,28 @@ function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!), return projectβ(_dβ) end dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, dA, NoTangent(), dB, NoTangent(), NoTangent(), + return NoTangent(), dC, dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), + NoTangent(), dα, dβ, dbackend... end return C′, planarcontract_pullback end +function ChainRulesCore.rrule(::typeof(TensorKit.planartrace!), + C::AbstractTensorMap, + A::AbstractTensorMap, + p::Index2Tuple, q::Index2Tuple, conjA::Symbol, + α::Number, β::Number, backend::Backend...) + C′ = planartrace!(copy(C), A, p, q, conjA, α, β, backend...) + + function planartrace_pullback(ΔC′) + return ΔC = unthunk(ΔC′) + end + + return C′, planartrace_pullback +end + # Convert rrules #---------------- function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) diff --git a/src/planar/functions.jl b/src/planar/functions.jl index e2d6ec65b..301cdab50 100644 --- a/src/planar/functions.jl +++ b/src/planar/functions.jl @@ -12,8 +12,8 @@ end # ------------------------------------------------------------------------------------------ - -function planartrace(A, pA::Index2Tuple, qA::Index2Tuple, conjA::Symbol, α::Number=One(), backend::Backend...) +function planartrace(A, pA::Index2Tuple, qA::Index2Tuple, conjA::Symbol, α::Number=One(), + backend::Backend...) TC = TO.promote_contract(scalartype(A), scalartype(α)) C = tensoralloc_add(TC, pA, A, conjA) return planartrace!(C, A, pA, qA, conjA, α, Zero(), backend...) @@ -21,8 +21,6 @@ end # ------------------------------------------------------------------------------------------ - - """ planarcontract(A, IA, [conjA], B, IB, [conjB], [IC], [α=1]) planarcontract(A, pA::Index2Tuple, conjA, B, pB::Index2Tuple, conjB, pAB::Index2Tuple, α=1, [backend]) # expert mode @@ -43,7 +41,8 @@ See also [`tensorcontract`](@ref). """ function planarcontract end -function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, conjB::Symbol, IC::TensorLabels, +function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, + conjB::Symbol, IC::TensorLabels, α::Number=One()) ia = canonicalize_labels(A, IA) ib = canonicalize_labels(B, IB) @@ -52,7 +51,8 @@ function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, return planarcontract(A, pA, conjA, B, pB, conjB, pAB, α) end # default `IC` -function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, conjB::Symbol, α::Number=One()) +function planarcontract(A, IA::TensorLabels, conjA::Symbol, B, IB::TensorLabels, + conjB::Symbol, α::Number=One()) ia = canonicalize_labels(A, IA) ib = canonicalize_labels(B, IB) pA, pB, pAB = planarcontract_indices(ia, ib) diff --git a/src/planar/indices.jl b/src/planar/indices.jl index 51a06aaee..468edae9a 100644 --- a/src/planar/indices.jl +++ b/src/planar/indices.jl @@ -14,12 +14,12 @@ end function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA, IC::Tuple{Tuple,Tuple}) IA′ = conjA == :C ? reverse(IA) : IA - + IA_linear = (IA′[1]..., (IA′[2])...) IC_linear = (IC[1]..., (IC[2])...) - + p, q1, q2 = TO.trace_indices(IA_linear, IC_linear) - + if conjA == :C p′ = adjointtensorindices((length(IC[1]), length(IC[2])), p) q′ = adjointtensorindices((length(IA[2]), length(IA[1])), (q1, q2)) @@ -32,7 +32,7 @@ function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA, IC::Tuple{Tuple,Tupl end function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA) IA′ = conjA == :C ? reverse(IA) : IA - + IA_linear = (IA′[1]..., (IA′[2])...) IC_linear = tuple(TO.unique2(IA_linear)...) p, q1, q2 = TO.trace_indices(IA_linear, IC_linear) @@ -44,11 +44,10 @@ function planartrace_indices(IA::Tuple{Tuple,Tuple}, conjA) p′ = p q′ = (q1, q2) end - + return p′, q′ end - """ planarcontract_indices(IA, IB, [IC]) @@ -58,16 +57,15 @@ achieved in a planar manner. function planarcontract_indices(IA::Tuple{Tuple,Tuple}, conjA::Symbol, IB::Tuple{Tuple,Tuple}, conjB::Symbol, IC::Tuple{Tuple,Tuple}) - IA′ = conjA == :C ? reverse(IA) : IA IB′ = conjB == :C ? reverse(IB) : IB - + pA, pB, pAB = planarcontract_indices(IA′, IB′, IC) - + # map indices back to original tensor pA′ = conjA == :C ? adjointtensorindices((length(IA[2]), length(IA[1])), pA) : pA pB′ = conjB == :C ? adjointtensorindices((length(IB[2]), length(IB[1])), pB) : pB - + return pA′, pB′, pAB end function planarcontract_indices(IA::Tuple{Tuple,Tuple}, conjA::Symbol, @@ -75,9 +73,9 @@ function planarcontract_indices(IA::Tuple{Tuple,Tuple}, conjA::Symbol, # map indices to indices of adjoint tensor IA′ = conjA == :C ? reverse(IA) : IA IB′ = conjB == :C ? reverse(IB) : IB - + pA, pB, pAB = planarcontract_indices(IA′, IB′) - + # map indices back to original tensor pA′ = conjA == :C ? adjointtensorindices((length(IA[2]), length(IA[1])), pA) : pA pB′ = conjB == :C ? adjointtensorindices((length(IB[2]), length(IB[1])), pB) : pB diff --git a/src/planar/planaroperations.jl b/src/planar/planaroperations.jl index 8df67b539..1ddf0347f 100644 --- a/src/planar/planaroperations.jl +++ b/src/planar/planaroperations.jl @@ -3,23 +3,23 @@ # ---------- function planaradd!(C::AbstractTensorMap{S}, - A::AbstractTensorMap{S}, conjA::Symbol, - p::Index2Tuple, + A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, backend::Backend...) where {S} if conjA == :N A′ = A - p′ = _canonicalize(p, C) + p′ = _canonicalize(pA, C) elseif conjA == :C A′ = adjoint(A) - p′ = adjointtensorindices(A, _canonicalize(p, C)) + p′ = adjointtensorindices(A, _canonicalize(pA, C)) else throw(ArgumentError("unknown conjugation flag $conjA")) end return add_transpose!(C, A′, p′, α, β, backend...) end -function planartrace!(C::AbstractTensorMap{S}, p::Index2Tuple, - A::AbstractTensorMap{S}, q::Index2Tuple, conjA::Symbol, +function planartrace!(C::AbstractTensorMap{S}, + A::AbstractTensorMap{S}, p::Index2Tuple, q::Index2Tuple, + conjA::Symbol, α::Number, β::Number, backend::Backend...) where {S} if conjA == :N A′ = A @@ -80,7 +80,7 @@ function trace_transpose!(tdst::AbstractTensorMap{S,N₁,N₂}, q₁ = $(q₁), q₂ = $(q₂)")) # TODO: check planarity? end - + # TODO: not sure if this is worth it if BraidingStyle(sectortype(S)) == Bosonic() return @inbounds trace_permute!(tdst, tsrc, p, q, α, β, backend...) @@ -88,7 +88,7 @@ function trace_transpose!(tdst::AbstractTensorMap{S,N₁,N₂}, scale!(tdst, β) β′ = One() - + for (f₁, f₂) in fusiontrees(tsrc) @inbounds A = tsrc[f₁, f₂] for ((f₁′, f₂′), coeff) in planar_trace(f₁, f₂, p₁, p₂, q₁, q₂) @@ -109,7 +109,7 @@ function _planarcontract!(C::AbstractTensorMap{S}, indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) indAB = (ntuple(identity, N₁), reverse(ntuple(i -> i + N₁, N₂))) - + # TODO: avoid this step once @planar is reimplemented pA′, pB′, pAB′ = reorder_planar_indices(indA, pA, indB, pB, pAB) diff --git a/src/planar/plancon.jl b/src/planar/plancon.jl index e5fca6d7d..d70bb31c7 100644 --- a/src/planar/plancon.jl +++ b/src/planar/plancon.jl @@ -50,9 +50,9 @@ function planarcontracttree(tensors, network, conjlist, tree::Int, output) A = tensors[tree] IA = canonicalize_labels(A, network[tree]) conjA = conjlist[tree] ? :C : :N - + pA, qA = planartrace_indices(IA, conjA, output) - + if isempty(qA[1]) # no traced indices return planarcopy(A, pA, conjA) C = tensoralloc_add(scalartype(A), pA, A, conjA) @@ -81,11 +81,10 @@ function planarcontracttree(tensors, network, conjlist, tree::Int) TupleTools.getindices(linearize(IA), pA[2])) conjC = :N end - + return C, IC, conjC end - function planarcontracttree(tensors, network, conjlist, tree) @assert !(tree isa Int) "single-node tree should already have been handled" A, IA, CA = planarcontracttree(tensors, network, conjlist, tree[1]) @@ -136,4 +135,4 @@ function plancontree(network) end partialtrees = collect(Any, 1:length(network)) return TO._ncontree!(partialtrees, contractionindices) -end \ No newline at end of file +end diff --git a/src/planar/postprocessors.jl b/src/planar/postprocessors.jl index 083871ffd..46b4ccb03 100644 --- a/src/planar/postprocessors.jl +++ b/src/planar/postprocessors.jl @@ -52,7 +52,7 @@ end # planar operations, immediately inserting them with `GlobalRef`. # NOTE: work around a somewhat unfortunate interface choice in TensorOperations, which we will correct in the future. -_planaradd!(C, p, A, α, β, backend...) = planaradd!(C, A, :N, p, α, β, backend...) +_planaradd!(C, p, A, α, β, backend...) = planaradd!(C, A, p, :N, α, β, backend...) _planartrace!(C, p, A, q, α, β, backend...) = planartrace!(C, A, :N, p, q, α, β, backend...) function _planarcontract!(C, pAB, A, pA, B, pB, α, β, backend...) return planarcontract!(C, A, pA, :N, B, pB, :N, pAB, α, β, backend...) diff --git a/test/ad.jl b/test/ad.jl index 9587916dd..dc158e730 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -238,40 +238,40 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) α = randn(T) β = randn(T) - test_rrule(planaradd!, C, A, p, α, β; atol, rtol) + test_rrule(planaradd!, C, A, p, :N, α, β; atol, rtol) end - + @testset "planarcontract! 1" begin A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, V[1] ⊗ V[5] ← V[5] ⊗ V[2]) - pA = ((1, 3, 4), (5, 2)) + pA = ((4, 3, 1), (5, 2)) pB = ((2, 4), (1, 3)) - pAB = ((3, 2, 1), (4, 5)) - + pAB = ((1, 2, 3), (4, 5)) + α = randn(T) β = randn(T) C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, :N, B, pB, :N, false)) - test_rrule(planarcontract!, C, A, pA, B, pB, pAB, α, β; atol, rtol) + test_rrule(planarcontract!, C, A, pA, :N, B, pB, :N, pAB, α, β; atol, rtol) end - + @testset "planarcontract! 2" begin A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, V[3] ⊗ V[4] ⊗ V[5] ← V[1] ⊗ V[2]) pA = ((1, 2), (3, 4, 5)) pB = ((1, 2, 3), (4, 5)) pAB = ((1, 2), (3, 4)) - + α = randn(T) β = randn(T) C = _randomize!(TensorOperations.tensoralloc_contract(T, pAB, A, pA, :N, B, pB, :N, false)) - test_rrule(planarcontract!, C, A, pA, B, pB, pAB, α, β; atol, rtol) + test_rrule(planarcontract!, C, A, pA, :N, B, pB, :N, pAB, α, β; atol, rtol) end end - + @testset "Factorizations with scalartype $T" for T in (Float64, ComplexF64) A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, space(A)') diff --git a/test/planar.jl b/test/planar.jl index 24d7a2086..4a373d6f4 100644 --- a/test/planar.jl +++ b/test/planar.jl @@ -57,7 +57,7 @@ using TensorKit: planarcontract_indices @test pA == ((4, 1, 2), (5, 3)) @test pB == ((2, 1), (3, 4)) @test pC == ((1, 2, 3), (4, 5)) - + IA = ((-1, 7), (6,)) IB = ((-4, -3, 6), (-2, 7)) IC = ((-1, -2), (-3, -4)) @@ -278,7 +278,7 @@ end p = ((4, 3), (5, 2, 1)) @test force_planar(tensoradd!(C, p, A, :N, true, true)) ≈ - planaradd!(C′, A′, :N, p, true, true) + planaradd!(C′, A′, p, :N, true, true) end @testset "planartrace" begin @@ -290,7 +290,7 @@ end q = ((1,), (3,)) @test force_planar(tensortrace!(C, p, A, q, :N, true, true)) ≈ - planartrace!(C′, p, A′, q, :N, true, true) + planartrace!(C′, A′, p, q, :N, true, true) end @testset "planarcontract" begin @@ -398,20 +398,22 @@ end v′ = force_planar(v) @tensor ρ[-1; -2] := x[-1 2; 1] * conj(x[-2 2; 3]) * v[1; 3] @planar ρ′[-1; -2] := x′[-1 2; 1] * conj(x′[-2 2; 3]) * v′[1; 3] - + ρ‴ = ncon([x, x, v], [[-1, 2, 1], [-2, 2, 3], [1, 3]], [false, true, false]) - ρ″ = plancon([x′, x′, v′], [[-1, 2, 1], [-2, 2, 3], [1, 3]], [false, true, false]; output=((-1,), (-2,))) + ρ″ = plancon([x′, x′, v′], [[-1, 2, 1], [-2, 2, 3], [1, 3]], [false, true, false]; + output=((-1,), (-2,))) @test force_planar(ρ) ≈ ρ′ @test force_planar(ρ) ≈ ρ″ - + @tensor ρ2[-1 -2; -3] := GL[1 -2; 3] * x[3 2; -3] * conj(x[1 2; -1]) @plansor ρ3[-1 -2; -3] := GL[1 2; 4] * x[4 5; -3] * τ[2 3; 5 -2] * conj(x[1 3; -1]) @planar ρ2′[-1 -2; -3] := GL′[1 2; 4] * x′[4 5; -3] * τ[2 3; 5 -2] * conj(x′[1 3; -1]) τtensor = BraidingTensor(space(x′, 2), space(GL′, 2)') - ρ2″ = plancon([GL′, x′, τtensor, x′], [[1, 2, 4], [4, 5, -3], [2, 3, 5, -2], [1, 3, -1]], + ρ2″ = plancon([GL′, x′, τtensor, x′], + [[1, 2, 4], [4, 5, -3], [2, 3, 5, -2], [1, 3, -1]], [false, false, false, true]; output=((-1, -2), (-3,))) - + @test force_planar(ρ2) ≈ ρ2′ @test force_planar(ρ2) ≈ ρ2″ @test ρ2 ≈ ρ3 From 6de651183f98ffa78c6f24974149b5841176ac26 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 17 May 2024 10:06:47 +0200 Subject: [PATCH 26/29] Prevent tensoroperations test on anyonic tests --- test/ad.jl | 117 +++++++++++++++++++++++++++-------------------------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index dc158e730..6b80d0beb 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -164,69 +164,70 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), test_rrule(LinearAlgebra.norm, A, 2) end - @testset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64) - atol = precision(T) - rtol = precision(T) - - @testset "tensortrace!" begin - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[1] ⊗ V[5]) - pC = ((3, 5), (2,)) - pA = ((1,), (4,)) - α = randn(T) - β = randn(T) - - C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false)) - test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol) - - C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false)) - test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol) - end - - @testset "tensoradd!" begin - p = ((1, 3, 2), (5, 4)) - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) - α = randn(T) - β = randn(T) - test_rrule(tensoradd!, C, p, A, :N, α, β; atol, rtol) - - C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false)) - test_rrule(tensoradd!, C, p, A, :C, α, β; atol, rtol) - end - - @testset "tensorcontract!" begin - A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - B = TensorMap(randn, T, V[3] ⊗ V[1]' ← V[2]) - pC = ((3, 2), (4, 1)) - pA = ((2, 4, 5), (1, 3)) - pB = ((2, 1), (3,)) - α = randn(T) - β = randn(T) - - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, - B, pB, :N, false)) - test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β; atol, rtol) + BraidingStyle(sectortype(eltype(V))) isa Symmetric && + @testset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64) + atol = precision(T) + rtol = precision(T) + + @testset "tensortrace!" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[1] ⊗ V[5]) + pC = ((3, 5), (2,)) + pA = ((1,), (4,)) + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :N, false)) + test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol, rtol) + + C = _randomize!(TensorOperations.tensoralloc_add(T, pC, A, :C, false)) + test_rrule(tensortrace!, C, pC, A, pA, :C, α, β; atol, rtol) + end - A2 = TensorMap(randn, T, V[1]' ⊗ V[2]' ← V[3]' ⊗ V[4]' ⊗ V[5]') - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, - B, pB, :N, false)) - test_rrule(tensorcontract!, C, pC, A2, pA, :C, B, pB, :N, α, β; atol, rtol) + @testset "tensoradd!" begin + p = ((1, 3, 2), (5, 4)) + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :N, false)) + α = randn(T) + β = randn(T) + test_rrule(tensoradd!, C, p, A, :N, α, β; atol, rtol) - B2 = TensorMap(randn, T, V[3]' ⊗ V[1] ← V[2]') - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, - B2, pB, :C, false)) - test_rrule(tensorcontract!, C, pC, A, pA, :N, B2, pB, :C, α, β; atol, rtol) + C = _randomize!(TensorOperations.tensoralloc_add(T, p, A, :C, false)) + test_rrule(tensoradd!, C, p, A, :C, α, β; atol, rtol) + end - C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, - B2, pB, :C, false)) - test_rrule(tensorcontract!, C, pC, A2, pA, :C, B2, pB, :C, α, β; atol, rtol) - end + @testset "tensorcontract!" begin + A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + B = TensorMap(randn, T, V[3] ⊗ V[1]' ← V[2]) + pC = ((3, 2), (4, 1)) + pA = ((2, 4, 5), (1, 3)) + pB = ((2, 1), (3,)) + α = randn(T) + β = randn(T) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, + B, pB, :N, false)) + test_rrule(tensorcontract!, C, pC, A, pA, :N, B, pB, :N, α, β; atol, rtol) + + A2 = TensorMap(randn, T, V[1]' ⊗ V[2]' ← V[3]' ⊗ V[4]' ⊗ V[5]') + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, + B, pB, :N, false)) + test_rrule(tensorcontract!, C, pC, A2, pA, :C, B, pB, :N, α, β; atol, rtol) + + B2 = TensorMap(randn, T, V[3]' ⊗ V[1] ← V[2]') + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A, pA, :N, + B2, pB, :C, false)) + test_rrule(tensorcontract!, C, pC, A, pA, :N, B2, pB, :C, α, β; atol, rtol) + + C = _randomize!(TensorOperations.tensoralloc_contract(T, pC, A2, pA, :C, + B2, pB, :C, false)) + test_rrule(tensorcontract!, C, pC, A2, pA, :C, B2, pB, :C, α, β; atol, rtol) + end - @testset "tensorscalar" begin - A = Tensor(randn, T, ProductSpace{typeof(V[1]),0}()) - test_rrule(tensorscalar, A) + @testset "tensorscalar" begin + A = Tensor(randn, T, ProductSpace{typeof(V[1]),0}()) + test_rrule(tensorscalar, A) + end end - end @testset "PlanarOperations with scalartype $T" for T in (Float64, ComplexF64) atol = precision(T) From f9847a172fe861e006b9d05e2a85348d384648d6 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 17 May 2024 10:25:58 +0200 Subject: [PATCH 27/29] Fix `Base.summary(::TensorMap)` --- src/tensors/adjoint.jl | 4 +--- src/tensors/tensor.jl | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index 43ad110d3..ffe67a5c6 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -87,9 +87,7 @@ end # Show #------ -function Base.summary(t::AdjointTensorMap) - return print("AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") -end +Base.summary(io::IO, t::AdjointTensorMap) = print(io, "AdjointTensorMap($(space(t)))") function Base.show(io::IO, t::AdjointTensorMap{S}) where {S<:IndexSpace} if get(io, :compact, false) print(io, "AdjointTensorMap(", codomain(t), " ← ", domain(t), ")") diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 8c8347ddc..86e07c692 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -679,9 +679,8 @@ end # Show #------ -function Base.summary(t::TensorMap) - return print("TensorMap(", space(t), ")") -end +Base.summary(io::IO, t::TensorMap) = print(io, "TensorMap(", space(t), ")") + function Base.show(io::IO, t::TensorMap{S}) where {S<:IndexSpace} if get(io, :compact, false) print(io, "TensorMap(", space(t), ")") From e8b900dd3b38e53b1d8d4c468738623dd5e34f1b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 17 May 2024 10:38:38 +0200 Subject: [PATCH 28/29] make otimes AD planar-compatible --- ext/TensorKitChainRulesCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 40ff4c69e..c174cbf8d 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -118,7 +118,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe pB = (allind(B), ()) dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C) + dA = planarcontract!(dA, ΔC, pΔC, :N, B, pB, :C, ipA, One(), Zero()) return projectA(dA) end dB_ = @thunk begin @@ -126,7 +126,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe pA = ((), allind(A)) dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N) + dB = planarcontract!(dB, A, pA, :C, ΔC, pΔC, :N, ipB, One(), Zero()) return projectB(dB) end return NoTangent(), dA_, dB_ From 236e3f1b0ae90d70dd3a3f751374faf199151e34 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 17 May 2024 12:17:45 +0200 Subject: [PATCH 29/29] Formatter --- src/tensors/braidingtensor.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index c41b26bcb..a870e0c6e 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -233,7 +233,6 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, pAB::Index2Tuple{N₁,N₂}, α::Number, β::Number, backend::Backend...) where {S,N₁,N₂,N₃} - indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) pA, pB, pAB = reorder_planar_indices(indA, pA, indB, pB, pAB) @@ -283,13 +282,12 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, backend::Backend...) where {S,N₁,N₂,N₃} codA, domA = codomainind(A), domainind(A) codB, domB = codomainind(B), domainind(B) - + indA = (codomainind(A), reverse(domainind(A))) indB = (codomainind(B), reverse(domainind(B))) pA, pB, pAB = reorder_planar_indices(indA, pA, indB, pB, pAB) oindA, cindA = pA cindB, oindB = pB - if space(B, cindB[1]) != space(A, cindA[1])' || space(B, cindB[2]) != space(A, cindA[2])'