Skip to content

Commit 31c912f

Browse files
lkdvos and Jutho authored
avoid twist allocating in rrules when not required (#306)
* avoid twist copies in rrules
* Add nocopy twists in tensorproduct
* Update ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Co-authored-by: Jutho <[email protected]>

---------

Co-authored-by: Jutho <[email protected]>
1 parent b46bfc4 commit 31c912f

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
4242
ipA = (codomainind(A), domainind(A))
4343
pB = (allind(B), ())
4444
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
45-
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
45+
tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B)))
4646
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
4747
return projectA(dA)
4848
end
4949
dB_ = @thunk let
5050
ipB = (codomainind(B), domainind(B))
5151
pA = ((), allind(A))
5252
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
53-
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)))
53+
tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A)))
5454
dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB)
5555
return projectB(dB)
5656
end

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function ChainRulesCore.rrule(
2828
# for non-symmetric tensors this might be more efficient like this,
2929
# but for symmetric tensors an intermediate object will anyways be created
3030
# and then it might be more efficient to use an addition and inner product
31-
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
31+
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
3232
_dα = tensorscalar(
3333
tensorcontract(
3434
A, ((), linearize(pA)), !conjA,
@@ -74,7 +74,7 @@ function ChainRulesCore.rrule(
7474
conjB′ = conjA ? conjB : !conjB
7575
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
7676
# TODO: allocator
77-
tB = twist(
77+
tB = _twist_nocopy(
7878
B,
7979
TupleTools.vcat(
8080
filter(x -> !isdual(space(B, x)), pB[1]),
@@ -99,7 +99,7 @@ function ChainRulesCore.rrule(
9999
conjA′ = conjB ? conjA : !conjA
100100
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
101101
# TODO: allocator
102-
tA = twist(
102+
tA = _twist_nocopy(
103103
A,
104104
TupleTools.vcat(
105105
filter(x -> isdual(space(A, x)), pA[1]),
@@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
188188
end
189189
return val, scalar_pullback
190190
end
191+
192+
# temporary function to avoid copies when not needed
193+
# TODO: remove once `twist(t; copy=false)` is defined
194+
# _twist_nocopy(t, inds; kwargs...)
#
# Return `t` itself, without calling `twist`, when either
#   * `BraidingStyle(sectortype(t)) isa Bosonic`, or
#   * `inds` is empty;
# otherwise return `twist(t, inds; kwargs...)`, forwarding all keyword arguments.
#
# In the short-circuit case the returned value is the very same object as the
# input `t` (no new tensor is created here), so callers should treat the result
# as read-only.
# NOTE(review): presumably `twist` would only hand back an unmodified copy of `t`
# in those two cases, which is what makes skipping it safe — confirm against the
# `twist` definition.
function _twist_nocopy(t, inds; kwargs...)
195+
(BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t
196+
return twist(t, inds; kwargs...)
197+
end

0 commit comments

Comments
 (0)