Skip to content

Commit 4852920

Browse files
committed
avoid twist copies in rrules
1 parent 9ba3b6c commit 4852920

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function ChainRulesCore.rrule(
2828
# for non-symmetric tensors this might be more efficient like this,
2929
# but for symmetric tensors an intermediate object will anyways be created
3030
# and then it might be more efficient to use an addition and inner product
31-
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
31+
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
3232
_dα = tensorscalar(
3333
tensorcontract(
3434
A, ((), linearize(pA)), !conjA,
@@ -74,7 +74,7 @@ function ChainRulesCore.rrule(
7474
conjB′ = conjA ? conjB : !conjB
7575
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
7676
# TODO: allocator
77-
tB = twist(
77+
tB = _twist_nocopy(
7878
B,
7979
TupleTools.vcat(
8080
filter(x -> !isdual(space(B, x)), pB[1]),
@@ -99,7 +99,7 @@ function ChainRulesCore.rrule(
9999
conjA′ = conjB ? conjA : !conjA
100100
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
101101
# TODO: allocator
102-
tA = twist(
102+
tA = _twist_nocopy(
103103
A,
104104
TupleTools.vcat(
105105
filter(x -> isdual(space(A, x)), pA[1]),
@@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
188188
end
189189
return val, scalar_pullback
190190
end
191+
192+
# temporary function to avoid copies when not needed
193+
# TODO: remove once `twist(t; copy=false)` is defined
194+
function _twist_nocopy(t, inds)
195+
(BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t
196+
return twist(t, inds)
197+
end

0 commit comments

Comments
 (0)