Skip to content

Commit faaa89b

Browse files
authored
try fix (without testing)
1 parent d0205de commit faaa89b

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
6161
pΔC = _repartition(ipAB, TO.numout(pA))
6262

6363
dC = @thunk projectC(scale(ΔC, conj(β)))
64-
dA = @thunk let
65-
ipA = _repartition(invperm(linearize(pA)), A)
64+
dA = @thunk let ipA = _repartition(invperm(linearize(pA)), A)
6665
conjΔC = conjA
6766
conjB′ = conjA ? conjB : !conjB
6867
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
@@ -78,8 +77,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7877
conjA ? α : conj(α), Zero(), ba...)
7978
return projectA(_dA)
8079
end
81-
dB = @thunk let
82-
ipB = _repartition(invperm(linearize(pB)), B)
80+
dB = @thunk let ipB = _repartition(invperm(linearize(pB)), B)
8381
conjΔC = conjB
8482
conjA′ = conjB ? conjA : !conjA
8583
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
@@ -125,9 +123,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
125123
function pullback(ΔC′)
126124
ΔC = unthunk(ΔC′)
127125
dC = @thunk projectC(scale(ΔC, conj(β)))
128-
dA = @thunk let
129-
ip = invperm((linearize(p)..., q[1]..., q[2]...))
130-
pdA = _repartition(ip, A)
126+
dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)), pdA = _repartition(ip, A)
131127
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
132128
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
133129
pE = ((), trivtuple(TO.numind(q)))

0 commit comments

Comments
 (0)