diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl index e1cba7876..16c7583d1 100644 --- a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -8,7 +8,7 @@ using LinearAlgebra using TupleTools import TensorOperations as TO -using TensorOperations: promote_contract +using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract using VectorInterface: promote_scale, promote_add include("utility.jl") diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 4f081a035..f3dac3c5b 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ipA = invperm(linearize(pA)) - _dA = zerovector(A, promote_add(ΔC, α)) - _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) + pdA = _repartition(ipA, A) + TA = promote_add(ΔC, α) + # TODO: allocator + _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false)) + _dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end dα = @thunk let @@ -55,19 +58,20 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), function pullback(ΔC′) ΔC = unthunk(ΔC′) ipAB = invperm(linearize(pAB)) - pΔC = (TupleTools.getindices(ipAB, trivtuple(TO.numout(pA))), - TupleTools.getindices(ipAB, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) + pΔC = _repartition(ipAB, TO.numout(pA)) dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let - ipA = (invperm(linearize(pA)), ()) + ipA = _repartition(invperm(linearize(pA)), A) conjΔC = conjA conjB′ = conjA ? conjB : !conjB - _dA = zerovector(A, - promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) + # TODO: allocator tB = twist(B, TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), filter(x -> isdual(space(B, x)), pB[2]))) + _dA = tensoralloc_contract(TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, + Val(false)) _dA = tensorcontract!(_dA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, @@ -75,14 +79,16 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), return projectA(_dA) end dB = @thunk let - ipB = (invperm(linearize(pB)), ()) + ipB = _repartition(invperm(linearize(pB)), B) conjΔC = conjB conjA′ = conjB ? conjA : !conjA - _dB = zerovector(B, - promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) + # TODO: allocator tA = twist(A, TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), filter(x -> !isdual(space(A, x)), pA[2]))) + _dB = tensoralloc_contract(TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, + Val(false)) _dB = tensorcontract!(_dB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, @@ -121,12 +127,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!), dC = @thunk projectC(scale(ΔC, conj(β))) dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - _dA = zerovector(A, promote_scale(ΔC, α)) - _dA = tensorproduct!(_dA, ΔC, - (trivtuple(TO.numind(p)), ()), conjA, E, - ((), trivtuple(TO.numind(q))), conjA, (ip, ()), + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TA = promote_scale(ΔC, α) + # TODO: allocator + _dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val(false)) + _dA = tensorproduct!(_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba...) return projectA(_dA) end diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl index 5bdd4e4a0..f270e346a 100644 --- a/ext/TensorKitChainRulesCoreExt/utility.jl +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -2,18 +2,21 @@ # ------- trivtuple(N) = ntuple(identity, N) -function _repartition(p::IndexTuple, N₁::Int) +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return p[1:N₁], p[(N₁ + 1):end] + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) end -_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} return _repartition(p, N₁) end function _repartition(p::Union{IndexTuple,Index2Tuple}, - ::AbstractTensorMap{<:Any,N₁}) where {N₁} - return _repartition(p, N₁) + t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) end TensorKit.block(t::ZeroTangent, c::Sector) = t diff --git a/test/bugfixes.jl b/test/bugfixes.jl index 39c28ad29..faa348c0d 100644 --- a/test/bugfixes.jl +++ b/test/bugfixes.jl @@ -44,6 +44,7 @@ tensorfree!(t2) end + # https://github.com/Jutho/TensorKit.jl/issues/201 @testset "Issue #201" begin function f(A::AbstractTensorMap) U, S, V, = tsvd(A) @@ -71,4 +72,17 @@ grad4, = Zygote.gradient(g, convert(Array, B₀)) @test convert(Array, grad3) ≈ grad4 end + + # https://github.com/Jutho/TensorKit.jl/issues/209 + @testset "Issue #209" begin + function f(T, D) + @tensor T[1, 4, 1, 3] * D[3, 4] + end + V = Z2Space(2, 2) + D = DiagonalTensorMap(randn(4), V) + T = randn(V ⊗ V ← V ⊗ V) + g1, = Zygote.gradient(f, T, D) + g2, = Zygote.gradient(f, T, TensorMap(D)) + @test g1 ≈ g2 + end end