@@ -86,19 +86,19 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
8686 function pullback (ΔC′)
8787 ΔC = unthunk (ΔC′)
8888 dC = @thunk projectC (scale (ΔC, conj (β)))
89- dA = @thunk begin
89+ dA = @thunk let
9090 ipA = invperm (linearize (pA))
9191 _dA = zerovector (A, VectorInterface. promote_add (ΔC, α))
9292 _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
9393 return projectA (_dA)
9494 end
95- dα = @thunk begin
95+ dα = @thunk let
9696 _dα = tensorscalar (tensorcontract (A, ((), linearize (pA)), ! conjA,
9797 ΔC, (trivtuple (numind (pA)), ()), false ,
9898 ((), ()), One (), ba... ))
9999 return projectα (_dα)
100100 end
101- dβ = @thunk begin
101+ dβ = @thunk let
102102 # TODO : consider using `inner`
103103 _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pA))), true ,
104104 ΔC, (trivtuple (numind (pA)), ()), false ,
@@ -165,7 +165,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
165165 pΔC = (TupleTools. getindices (ipAB, trivtuple (numout (pA))),
166166 TupleTools. getindices (ipAB, numout (pA) .+ trivtuple (numin (pB))))
167167 dC = @thunk projectC (scale (ΔC, conj (β)))
168- dA = @thunk begin
168+ dA = @thunk let
169169 ipA = (invperm (linearize (pA)), ())
170170 conjΔC = conjA
171171 conjB′ = conjA ? conjB : ! conjB
@@ -177,7 +177,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
177177 conjA ? α : conj (α), Zero (), ba... )
178178 return projectA (_dA)
179179 end
180- dB = @thunk begin
180+ dB = @thunk let
181181 ipB = (invperm (linearize (pB)), ())
182182 conjΔC = conjB
183183 conjA′ = conjB ? conjA : ! conjA
@@ -189,15 +189,15 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
189189 conjB ? α : conj (α), Zero (), ba... )
190190 return projectB (_dB)
191191 end
192- dα = @thunk begin
192+ dα = @thunk let
193193 C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
194194 # TODO : consider using `inner`
195195 _dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (pAB))), true ,
196196 ΔC, (trivtuple (numind (pAB)), ()), false ,
197197 ((), ()), One (), ba... ))
198198 return projectα (_dα)
199199 end
200- dβ = @thunk begin
200+ dβ = @thunk let
201201 # TODO : consider using `inner`
202202 _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pAB))), true ,
203203 ΔC, (trivtuple (numind (pAB)), ()), false ,
@@ -249,7 +249,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
249249 function pullback (ΔC′)
250250 ΔC = unthunk (ΔC′)
251251 dC = @thunk projectC (scale (ΔC, conj (β)))
252- dA = @thunk begin
252+ dA = @thunk let
253253 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
254254 Es = map (q[1 ], q[2 ]) do i1, i2
255255 return one (TensorOperations. tensoralloc_add (scalartype (A), A,
@@ -263,15 +263,15 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
263263 conjA ? α : conj (α), Zero (), ba... )
264264 return projectA (_dA)
265265 end
266- dα = @thunk begin
266+ dα = @thunk let
267267 C_αβ = tensortrace (A, p, q, false , One (), ba... )
268268 _dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (p))),
269269 ! conjA,
270270 ΔC, (trivtuple (numind (p)), ()), false ,
271271 ((), ()), One (), ba... ))
272272 return projectα (_dα)
273273 end
274- dβ = @thunk begin
274+ dβ = @thunk let
275275 _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (p))), true ,
276276 ΔC, (trivtuple (numind (p)), ()), false ,
277277 ((), ()), One (), ba... ))
0 commit comments