Skip to content

Commit 7bbc830

Browse files
committed
remove temporary _twist_nocopy
1 parent fab0dc8 commit 7bbc830

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
4242
ipA = (codomainind(A), domainind(A))
4343
pB = (allind(B), ())
4444
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
45-
tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B)))
45+
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)); copy = false)
4646
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
4747
return projectA(dA)
4848
end
4949
dB_ = @thunk let
5050
ipB = (codomainind(B), domainind(B))
5151
pA = ((), allind(A))
5252
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
53-
tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A)))
53+
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)); copy = false)
5454
dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB)
5555
return projectB(dB)
5656
end

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function ChainRulesCore.rrule(
3636
# for non-symmetric tensors this might be more efficient like this,
3737
# but for symmetric tensors an intermediate object will anyways be created
3838
# and then it might be more efficient to use an addition and inner product
39-
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
39+
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false)
4040
_dα = tensorscalar(
4141
tensorcontract(
4242
A, ((), linearize(pA)), !conjA,
@@ -85,12 +85,12 @@ function ChainRulesCore.rrule(
8585
conjB′ = conjA ? conjB : !conjB
8686
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
8787
# TODO: allocator
88-
tB = _twist_nocopy(
88+
tB = twist(
8989
B,
9090
TupleTools.vcat(
9191
filter(x -> !isdual(space(B, x)), pB[1]),
9292
filter(x -> isdual(space(B, x)), pB[2])
93-
)
93+
); copy = false
9494
)
9595
_dA = tensoralloc_contract(
9696
TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, Val(false)
@@ -110,12 +110,12 @@ function ChainRulesCore.rrule(
110110
conjA′ = conjB ? conjA : !conjA
111111
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
112112
# TODO: allocator
113-
tA = _twist_nocopy(
113+
tA = twist(
114114
A,
115115
TupleTools.vcat(
116116
filter(x -> isdual(space(A, x)), pA[1]),
117117
filter(x -> !isdual(space(A, x)), pA[2])
118-
)
118+
); copy = false
119119
)
120120
_dB = tensoralloc_contract(
121121
TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, Val(false)
@@ -207,10 +207,3 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
207207
end
208208
return val, scalar_pullback
209209
end
210-
211-
# temporary function to avoid copies when not needed
212-
# TODO: remove once `twist(t; copy=false)` is defined
213-
function _twist_nocopy(t, inds; kwargs...)
214-
(BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t
215-
return twist(t, inds; kwargs...)
216-
end

0 commit comments

Comments
 (0)