@@ -61,7 +61,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
6161 pΔC = _repartition (ipAB, TO. numout (pA))
6262
6363 dC = @thunk projectC (scale (ΔC, conj (β)))
64- dA = @thunk let ipA = _repartition (invperm (linearize (pA)), A)
64+ dA = @thunk let
65+ ipA = _repartition (invperm (linearize (pA)), A)
6566 conjΔC = conjA
6667 conjB′ = conjA ? conjB : ! conjB
6768 TA = promote_contract (scalartype (ΔC), scalartype (B), scalartype (α))
@@ -77,7 +78,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7778 conjA ? α : conj (α), Zero (), ba... )
7879 return projectA (_dA)
7980 end
80- dB = @thunk let ipB = _repartition (invperm (linearize (pB)), B)
81+ dB = @thunk let
82+ ipB = _repartition (invperm (linearize (pB)), B)
8183 conjΔC = conjB
8284 conjA′ = conjB ? conjA : ! conjA
8385 TB = promote_contract (scalartype (ΔC), scalartype (A), scalartype (α))
@@ -123,7 +125,9 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
123125 function pullback (ΔC′)
124126 ΔC = unthunk (ΔC′)
125127 dC = @thunk projectC (scale (ΔC, conj (β)))
126- dA = @thunk let ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. )), pdA = _repartition (ip, A)
128+ dA = @thunk let
129+ ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
130+ pdA = _repartition (ip, A)
127131 E = one! (TO. tensoralloc_add (scalartype (A), A, q, conjA))
128132 twist! (E, filter (x -> ! isdual (space (E, x)), codomainind (E)))
129133 pE = ((), trivtuple (TO. numind (q)))
0 commit comments