From fab0dc897bf02663ca90df8de1fec2e05b52e601 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 20:54:56 +0100 Subject: [PATCH 1/2] TensorOperations chainrules optimizations for alpha and beta cases --- .../tensoroperations.jl | 83 ++++++++++++------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 73b60b5b7..ab054dcbd 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -1,3 +1,10 @@ +# To avoid computing rrules for α and β when these aren't needed, we want to have a +# type-stable quick bail-out +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + function ChainRulesCore.rrule( ::typeof(TensorOperations.tensoradd!), C::AbstractTensorMap, @@ -13,7 +20,7 @@ function ChainRulesCore.rrule( function pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) pdA = _repartition(ipA, A) @@ -21,24 +28,28 @@ function ChainRulesCore.rrule( # TODO: allocator _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false)) _dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...) - return projectA(_dA) + projectA(_dA) end - dα = @thunk let - # TODO: this is an inner product implemented as a contraction - # for non-symmetric tensors this might be more efficient like this, - # but for symmetric tensors an intermediate object will anyways be created - # and then it might be more efficient to use an addition and inner product - tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) - _dα = tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - tΔC, (trivtuple(TO.numind(pA)), ()), false, - ((), ()), One(), ba... + dα = if _needs_tangent(α) + @thunk let + # TODO: this is an inner product implemented as a contraction + # for non-symmetric tensors this might be more efficient like this, + # but for symmetric tensors an intermediate object will anyways be created + # and then it might be more efficient to use an addition and inner product + tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + _dα = tensorscalar( + tensorcontract( + A, ((), linearize(pA)), !conjA, + tΔC, (trivtuple(TO.numind(pA)), ()), false, + ((), ()), One(), ba... + ) ) - ) - return projectα(_dα) + projectα(_dα) + end + else + ZeroTangent() end - dβ = @thunk projectβ(inner(C, ΔC)) + dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent() dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... end @@ -67,7 +78,7 @@ function ChainRulesCore.rrule( ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = _repartition(invperm(linearize(pA)), A) conjΔC = conjA @@ -91,7 +102,7 @@ function ChainRulesCore.rrule( ipA, conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end dB = @thunk let ipB = _repartition(invperm(linearize(pB)), B) @@ -116,14 +127,18 @@ function ChainRulesCore.rrule( ipB, conjB ? α : conj(α), Zero(), ba... ) - return projectB(_dB) + projectB(_dB) end - dα = @thunk let - # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB - AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - return projectα(inner(AB, ΔC)) + dα = if _needs_tangent(α) + @thunk let + # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB + AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + projectα(inner(AB, ΔC)) + end + else + ZeroTangent() end - dβ = @thunk projectβ(inner(C, ΔC)) + dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent() dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), @@ -149,7 +164,7 @@ function ChainRulesCore.rrule( function pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) pdA = _repartition(ip, A) @@ -163,15 +178,19 @@ function ChainRulesCore.rrule( _dA = tensorproduct!( _dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba... ) - return projectA(_dA) + projectA(_dA) end - dα = @thunk let - # TODO: this result might be easier to compute as: - # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α - At = tensortrace(A, p, q, conjA) - return projectα(inner(At, ΔC)) + dα = if _needs_tangent(α) + @thunk let + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = tensortrace(A, p, q, conjA) + projectα(inner(At, ΔC)) + end + else + ZeroTangent() end - dβ = @thunk projectβ(inner(C, ΔC)) + dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent() dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba... end From 7bbc8301ef63c3010718a72db3e5cc0b19f1ca0b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 5 Jan 2026 20:57:13 +0100 Subject: [PATCH 2/2] remove temporary `_twist_nocopy` --- ext/TensorKitChainRulesCoreExt/linalg.jl | 4 ++-- .../tensoroperations.jl | 17 +++++------------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index d4510c2f1..fd27d410c 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -42,7 +42,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipA = (codomainind(A), domainind(A)) pB = (allind(B), ()) dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B))) + tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)); copy = false) dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA) return projectA(dA) end @@ -50,7 +50,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipB = (codomainind(B), domainind(B)) pA = ((), allind(A)) dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A))) + tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)); copy = false) dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB) return projectB(dB) end diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index ab054dcbd..dfa8eb72c 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -36,7 +36,7 @@ function ChainRulesCore.rrule( # for non-symmetric tensors this might be more efficient like this, # but for symmetric tensors an intermediate object will anyways be created # and then it might be more efficient to use an addition and inner product - tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false) _dα = tensorscalar( tensorcontract( A, ((), linearize(pA)), !conjA, @@ -85,12 +85,12 @@ function ChainRulesCore.rrule( conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) # TODO: allocator - tB = _twist_nocopy( + tB = twist( B, TupleTools.vcat( filter(x -> !isdual(space(B, x)), pB[1]), filter(x -> isdual(space(B, x)), pB[2]) - ) + ); copy = false ) _dA = tensoralloc_contract( TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, Val(false) @@ -110,12 +110,12 @@ function ChainRulesCore.rrule( conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) # TODO: allocator - tA = _twist_nocopy( + tA = twist( A, TupleTools.vcat( filter(x -> isdual(space(A, x)), pA[1]), filter(x -> !isdual(space(A, x)), pA[2]) - ) + ); copy = false ) _dB = tensoralloc_contract( TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, Val(false) @@ -207,10 +207,3 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap) end return val, scalar_pullback end - -# temporary function to avoid copies when not needed -# TODO: remove once `twist(t; copy=false)` is defined -function _twist_nocopy(t, inds; kwargs...) - (BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t - return twist(t, inds; kwargs...) -end