Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: promote_contract
using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract
using VectorInterface: promote_scale, promote_add

include("utility.jl")
Expand Down
37 changes: 23 additions & 14 deletions ext/TensorKitChainRulesCoreExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ipA = invperm(linearize(pA))
_dA = zerovector(A, promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
pdA = _repartition(ipA, A)
TA = promote_add(ΔC, α)
# TODO: allocator
_dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
_dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
dα = @thunk let
Expand Down Expand Up @@ -55,34 +58,37 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipAB = invperm(linearize(pAB))
pΔC = (TupleTools.getindices(ipAB, trivtuple(TO.numout(pA))),
TupleTools.getindices(ipAB, TO.numout(pA) .+ trivtuple(TO.numin(pB))))
pΔC = _repartition(ipAB, TO.numout(pA))

dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ipA = (invperm(linearize(pA)), ())
ipA = _repartition(invperm(linearize(pA)), A)
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
_dA = zerovector(A,
promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)))
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
# TODO: allocator
tB = twist(B,
TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])))
_dA = tensoralloc_contract(TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA,
Val(false))
_dA = tensorcontract!(_dA,
ΔC, pΔC, conjΔC,
tB, reverse(pB), conjB′, ipA,
conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
dB = @thunk let
ipB = (invperm(linearize(pB)), ())
ipB = _repartition(invperm(linearize(pB)), B)
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
_dB = zerovector(B,
promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)))
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
# TODO: allocator
tA = twist(A,
TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])))
_dB = tensoralloc_contract(TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB,
Val(false))
_dB = tensorcontract!(_dB,
tA, reverse(pA), conjA′,
ΔC, pΔC, conjΔC, ipB,
Expand Down Expand Up @@ -121,12 +127,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
pdA = _repartition(ip, A)
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
_dA = zerovector(A, promote_scale(ΔC, α))
_dA = tensorproduct!(_dA, ΔC,
(trivtuple(TO.numind(p)), ()), conjA, E,
((), trivtuple(TO.numind(q))), conjA, (ip, ()),
pE = ((), trivtuple(TO.numind(q)))
pΔC = (trivtuple(TO.numind(p)), ())
TA = promote_scale(ΔC, α)
# TODO: allocator
_dA = tensoralloc_contract(TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val(false))
_dA = tensorproduct!(_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA,
conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
end
Expand Down
13 changes: 8 additions & 5 deletions ext/TensorKitChainRulesCoreExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
# -------
trivtuple(N) = ntuple(identity, N)

function _repartition(p::IndexTuple, N₁::Int)
Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return p[1:N₁], p[(N₁ + 1):end]
return TupleTools.getindices(p, trivtuple(N₁)),
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
return _repartition(linearize(p), N₁)

Check warning on line 12 in ext/TensorKitChainRulesCoreExt/utility.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/utility.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
end
_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁)
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple,Index2Tuple},
::AbstractTensorMap{<:Any,N₁}) where {N₁}
return _repartition(p, N₁)
t::AbstractTensorMap)
return _repartition(p, TensorKit.numout(t))
end

TensorKit.block(t::ZeroTangent, c::Sector) = t
Expand Down
14 changes: 14 additions & 0 deletions test/bugfixes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
tensorfree!(t2)
end

# https://github.com/Jutho/TensorKit.jl/issues/201
@testset "Issue #201" begin
function f(A::AbstractTensorMap)
U, S, V, = tsvd(A)
Expand Down Expand Up @@ -71,4 +72,17 @@
grad4, = Zygote.gradient(g, convert(Array, B₀))
@test convert(Array, grad3) ≈ grad4
end

# https://github.com/Jutho/TensorKit.jl/issues/209
@testset "Issue #209" begin
function f(T, D)
@tensor T[1, 4, 1, 3] * D[3, 4]
end
V = Z2Space(2, 2)
D = DiagonalTensorMap(randn(4), V)
T = randn(V ⊗ V ← V ⊗ V)
g1, = Zygote.gradient(f, T, D)
g2, = Zygote.gradient(f, T, TensorMap(D))
@test g1 ≈ g2
end
end
Loading