@@ -61,8 +61,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
6161 pΔC = _repartition (ipAB, TO. numout (pA))
6262
6363 dC = @thunk projectC (scale (ΔC, conj (β)))
64- dA = @thunk let
65- ipA = _repartition (invperm (linearize (pA)), A)
64+ dA = @thunk let ipA = _repartition (invperm (linearize (pA)), A)
6665 conjΔC = conjA
6766 conjB′ = conjA ? conjB : ! conjB
6867 TA = promote_contract (scalartype (ΔC), scalartype (B), scalartype (α))
@@ -78,8 +77,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7877 conjA ? α : conj (α), Zero (), ba... )
7978 return projectA (_dA)
8079 end
81- dB = @thunk let
82- ipB = _repartition (invperm (linearize (pB)), B)
80+ dB = @thunk let ipB = _repartition (invperm (linearize (pB)), B)
8381 conjΔC = conjB
8482 conjA′ = conjB ? conjA : ! conjA
8583 TB = promote_contract (scalartype (ΔC), scalartype (A), scalartype (α))
@@ -125,9 +123,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
125123 function pullback (ΔC′)
126124 ΔC = unthunk (ΔC′)
127125 dC = @thunk projectC (scale (ΔC, conj (β)))
128- dA = @thunk let
129- ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
130- pdA = _repartition (ip, A)
126+ dA = @thunk let ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. )), pdA = _repartition (ip, A)
131127 E = one! (TO. tensoralloc_add (scalartype (A), A, q, conjA))
132128 twist! (E, filter (x -> ! isdual (space (E, x)), codomainind (E)))
133129 pE = ((), trivtuple (TO. numind (q)))
0 commit comments