@@ -28,7 +28,7 @@ function ChainRulesCore.rrule(
2828 # for non-symmetric tensors this might be more efficient like this,
2929 # but for symmetric tensors an intermediate object will anyways be created
3030 # and then it might be more efficient to use an addition and inner product
31- tΔC = twist (ΔC, filter (x -> isdual (space (ΔC, x)), allind (ΔC)))
31+ tΔC = _twist_nocopy (ΔC, filter (x -> isdual (space (ΔC, x)), allind (ΔC)))
3232 _dα = tensorscalar (
3333 tensorcontract (
3434 A, ((), linearize (pA)), ! conjA,
@@ -74,7 +74,7 @@ function ChainRulesCore.rrule(
7474 conjB′ = conjA ? conjB : ! conjB
7575 TA = promote_contract (scalartype (ΔC), scalartype (B), scalartype (α))
7676 # TODO : allocator
77- tB = twist (
77+ tB = _twist_nocopy (
7878 B,
7979 TupleTools. vcat (
8080 filter (x -> ! isdual (space (B, x)), pB[1 ]),
@@ -99,7 +99,7 @@ function ChainRulesCore.rrule(
9999 conjA′ = conjB ? conjA : ! conjA
100100 TB = promote_contract (scalartype (ΔC), scalartype (A), scalartype (α))
101101 # TODO : allocator
102- tA = twist (
102+ tA = _twist_nocopy (
103103 A,
104104 TupleTools. vcat (
105105 filter (x -> isdual (space (A, x)), pA[1 ]),
@@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
188188 end
189189 return val, scalar_pullback
190190end
191+
192+ # temporary function to avoid copies when not needed
193+ # TODO : remove once `twist(t; copy=false)` is defined
194+ function _twist_nocopy (t, inds)
195+ (BraidingStyle (sectortype (t)) isa Fermionic && ! isempty (inds)) || return t
196+ return twist (t, inds)
197+ end
0 commit comments