Skip to content

Commit 591b3fb

Browse files
committed
Add nocopy twists in tensorproduct
1 parent 4852920 commit 591b3fb

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
4242
ipA = (codomainind(A), domainind(A))
4343
pB = (allind(B), ())
4444
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
45-
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
45+
tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B)))
4646
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
4747
return projectA(dA)
4848
end
4949
dB_ = @thunk let
5050
ipB = (codomainind(B), domainind(B))
5151
pA = ((), allind(A))
5252
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
53-
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)))
53+
tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A)))
5454
dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB)
5555
return projectB(dB)
5656
end

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ end
191191

192192
# temporary function to avoid copies when not needed
193193
# TODO: remove once `twist(t; copy=false)` is defined
194-
function _twist_nocopy(t, inds)
194+
function _twist_nocopy(t, inds; kwargs...)
195195
(BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t
196-
return twist(t, inds)
196+
return twist(t, inds; kwargs...)
197197
end

0 commit comments

Comments
 (0)