From 27f3c46d0d1fe682e768d0c855a14bda8c2eb9a4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Jan 2025 16:28:27 -0500 Subject: [PATCH 1/7] Rewrite AD for TensorOperations in terms of `similar` instead of `zerovector` --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 4f081a035..b251c7be0 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -14,7 +14,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) - _dA = zerovector(A, promote_add(ΔC, α)) + _dA = similar(A, promote_add(ΔC, α)) _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end @@ -63,8 +63,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), ipA = (invperm(linearize(pA)), ()) conjΔC = conjA conjB′ = conjA ? conjB : !conjB - _dA = zerovector(A, - promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + _dA = similar(A, promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) tB = twist(B, TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), filter(x -> isdual(space(B, x)), pB[2]))) @@ -78,8 +77,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), ipB = (invperm(linearize(pB)), ()) conjΔC = conjB conjA′ = conjB ? 
conjA : !conjA - _dB = zerovector(B, - promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + _dB = similar(B, promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) tA = twist(A, TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), filter(x -> !isdual(space(A, x)), pA[2]))) @@ -123,7 +121,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), ip = invperm((linearize(p)..., q[1]..., q[2]...)) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - _dA = zerovector(A, promote_scale(ΔC, α)) + _dA = similar(A, promote_scale(ΔC, α)) _dA = tensorproduct!(_dA, ΔC, (trivtuple(TO.numind(p)), ()), conjA, E, ((), trivtuple(TO.numind(q))), conjA, (ip, ()), From e4f764cd101d9430888cc38370ad6f95be90c3d7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Jan 2025 16:31:14 -0500 Subject: [PATCH 2/7] Add testcase --- test/bugfixes.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/bugfixes.jl b/test/bugfixes.jl index 39c28ad29..e65fbe7ba 100644 --- a/test/bugfixes.jl +++ b/test/bugfixes.jl @@ -71,4 +71,16 @@ grad4, = Zygote.gradient(g, convert(Array, B₀)) @test convert(Array, grad3) ≈ grad4 end + + @testset "Issue #209" begin + function f(T, D) + @tensor T[1, 4, 1, 3] * D[3, 4] + end + V = Z2Space(2, 2) + D = DiagonalTensorMap(randn(4), V) + T = randn(V ⊗ V ← V ⊗ V) + g1, = Zygote.gradient(f, T, D) + g2, = Zygote.gradient(f, T, TensorMap(D)) + @test g1 ≈ g2 + end end From 0372740b137796a20c251cbeb1b7857426cc88ad Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Jan 2025 16:31:51 -0500 Subject: [PATCH 3/7] Add links --- test/bugfixes.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/bugfixes.jl b/test/bugfixes.jl index e65fbe7ba..faa348c0d 100644 --- a/test/bugfixes.jl +++ b/test/bugfixes.jl @@ -44,6 +44,7 @@ tensorfree!(t2) end + # https://github.com/Jutho/TensorKit.jl/issues/201 @testset "Issue #201" begin function 
f(A::AbstractTensorMap) U, S, V, = tsvd(A) @@ -72,6 +73,7 @@ @test convert(Array, grad3) ≈ grad4 end + # https://github.com/Jutho/TensorKit.jl/issues/209 @testset "Issue #209" begin function f(T, D) @tensor T[1, 4, 1, 3] * D[3, 4] From c07e7770fe540596b354a8fbd65d5b02cb87a800 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Jan 2025 07:19:38 -0500 Subject: [PATCH 4/7] Rewrite in terms of tensoralloc --- .../TensorKitChainRulesCoreExt.jl | 2 +- .../tensoroperations.jl | 28 +++++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index e1cba7876..16c7583d1 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -8,7 +8,7 @@ using LinearAlgebra using TupleTools import TensorOperations as TO -using TensorOperations: promote_contract +using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract using VectorInterface: promote_scale, promote_add include("utility.jl") diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index b251c7be0..f06bba05c 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) - _dA = similar(A, promote_add(ΔC, α)) - _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) + pdA = (ipA, ()) + TA = promote_add(ΔC, α) + # TODO: allocator + _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false)) + _dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...) 
return projectA(_dA) end dα = @thunk let @@ -63,10 +66,13 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), ipA = (invperm(linearize(pA)), ()) conjΔC = conjA conjB′ = conjA ? conjB : !conjB - _dA = similar(A, promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) + # TODO: allocator tB = twist(B, TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), filter(x -> isdual(space(B, x)), pB[2]))) + _dA = tensoralloc_contract(TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, + Val(false)) _dA = tensorcontract!(_dA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, @@ -77,10 +83,13 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), ipB = (invperm(linearize(pB)), ()) conjΔC = conjB conjA′ = conjB ? conjA : !conjA - _dB = similar(B, promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) + # TODO: allocator tA = twist(A, TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), filter(x -> !isdual(space(A, x)), pA[2]))) + _dB = tensoralloc_contract(TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, + Val(false)) _dB = tensorcontract!(_dB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, @@ -119,12 +128,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = (ip, ()) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - _dA = similar(A, promote_scale(ΔC, α)) - _dA = tensorproduct!(_dA, ΔC, - (trivtuple(TO.numind(p)), ()), conjA, E, - ((), trivtuple(TO.numind(q))), conjA, (ip, ()), + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TA = promote_scale(ΔC, α) + # TODO: allocator + _dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, 
Val(false)) + _dA = tensorproduct!(_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end From d0205dec28970fb664ad99c35aa9c42113b43380 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Jan 2025 07:33:54 -0500 Subject: [PATCH 5/7] Be consistent with indextuple lengths --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 11 +++++------ ext/TensorKitChainRulesCoreExt/utility.jl | 7 ++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index f06bba05c..f3dac3c5b 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -14,7 +14,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) - pdA = (ipA, ()) + pdA = _repartition(ipA, A) TA = promote_add(ΔC, α) # TODO: allocator _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false)) @@ -58,12 +58,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), function pullback(ΔC′) ΔC = unthunk(ΔC′) ipAB = invperm(linearize(pAB)) - pΔC = (TupleTools.getindices(ipAB, trivtuple(TO.numout(pA))), - TupleTools.getindices(ipAB, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) + pΔC = _repartition(ipAB, TO.numout(pA)) dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let - ipA = (invperm(linearize(pA)), ()) + ipA = _repartition(invperm(linearize(pA)), A) conjΔC = conjA conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) @@ -80,7 +79,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), return projectA(_dA) end dB = @thunk let - ipB = (invperm(linearize(pB)), ()) + ipB = _repartition(invperm(linearize(pB)), B) conjΔC = conjB conjA′ = conjB ? 
conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) @@ -128,7 +127,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = (ip, ()) + pdA = _repartition(ip, A) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) pE = ((), trivtuple(TO.numind(q))) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index 5bdd4e4a0..f2fc19aab 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -5,15 +5,16 @@ trivtuple(N) = ntuple(identity, N) function _repartition(p::IndexTuple, N₁::Int) length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return p[1:N₁], p[(N₁ + 1):end] + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) end _repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} return _repartition(p, N₁) end function _repartition(p::Union{IndexTuple,Index2Tuple}, - ::AbstractTensorMap{<:Any,N₁}) where {N₁} - return _repartition(p, N₁) + t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) end TensorKit.block(t::ZeroTangent, c::Sector) = t From faaa89b5e6e78647d7c0db58197067efd5a64bdc Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 30 Jan 2025 18:22:01 +0100 Subject: [PATCH 6/7] try fix (without testing) --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index f3dac3c5b..86733a175 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ 
b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -61,8 +61,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), pΔC = _repartition(ipAB, TO.numout(pA)) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk let - ipA = _repartition(invperm(linearize(pA)), A) + dA = @thunk let ipA = _repartition(invperm(linearize(pA)), A) conjΔC = conjA conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) @@ -78,8 +77,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end - dB = @thunk let - ipB = _repartition(invperm(linearize(pB)), B) + dB = @thunk let ipB = _repartition(invperm(linearize(pB)), B) conjΔC = conjB conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) @@ -125,9 +123,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), function pullback(ΔC′) ΔC = unthunk(ΔC′) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk let - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = _repartition(ip, A) + dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)), pdA = _repartition(ip, A) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) pE = ((), trivtuple(TO.numind(q))) From 5fd86c028bc866f2c623200728580b522c5d208b Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 30 Jan 2025 22:14:03 +0100 Subject: [PATCH 7/7] proper fix (hopefully) --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 10 +++++++--- ext/TensorKitChainRulesCoreExt/utility.jl | 6 ++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 86733a175..f3dac3c5b 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ 
-61,7 +61,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), pΔC = _repartition(ipAB, TO.numout(pA)) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk let ipA = _repartition(invperm(linearize(pA)), A) + dA = @thunk let + ipA = _repartition(invperm(linearize(pA)), A) conjΔC = conjA conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) @@ -77,7 +78,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end - dB = @thunk let ipB = _repartition(invperm(linearize(pB)), B) + dB = @thunk let + ipB = _repartition(invperm(linearize(pB)), B) conjΔC = conjB conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) @@ -123,7 +125,9 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), function pullback(ΔC′) ΔC = unthunk(ΔC′) dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)), pdA = _repartition(ip, A) + dA = @thunk let + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) pE = ((), trivtuple(TO.numind(q))) diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index f2fc19aab..f270e346a 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -2,13 +2,15 @@ # ------- trivtuple(N) = ntuple(identity, N) -function _repartition(p::IndexTuple, N₁::Int) +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) return TupleTools.getindices(p, trivtuple(N₁)), TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) end -_repartition(p::Index2Tuple, N₁::Int) = 
_repartition(linearize(p), N₁) +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} return _repartition(p, N₁) end