@@ -12,13 +12,13 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
1212 function pullback (ΔC′)
1313 ΔC = unthunk (ΔC′)
1414 dC = @thunk projectC (scale (ΔC, conj (β)))
15- dA = @thunk begin
15+ dA = @thunk let
1616 ipA = invperm (linearize (pA))
1717 _dA = zerovector (A, promote_add (ΔC, α))
1818 _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
1919 return projectA (_dA)
2020 end
21- dα = @thunk begin
21+ dα = @thunk let
2222 # TODO : this is an inner product implemented as a contraction
2323 # for non-symmetric tensors this might be more efficient like this,
2424 # but for symmetric tensors an intermediate object will anyways be created
@@ -59,7 +59,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
5959 TupleTools. getindices (ipAB, TO. numout (pA) .+ trivtuple (TO. numin (pB))))
6060
6161 dC = @thunk projectC (scale (ΔC, conj (β)))
62- dA = @thunk begin
62+ dA = @thunk let
6363 ipA = (invperm (linearize (pA)), ())
6464 conjΔC = conjA
6565 conjB′ = conjA ? conjB : ! conjB
@@ -74,7 +74,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7474 conjA ? α : conj (α), Zero (), ba... )
7575 return projectA (_dA)
7676 end
77- dB = @thunk begin
77+ dB = @thunk let
7878 ipB = (invperm (linearize (pB)), ())
7979 conjΔC = conjB
8080 conjA′ = conjB ? conjA : ! conjA
@@ -89,7 +89,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
8989 conjB ? α : conj (α), Zero (), ba... )
9090 return projectB (_dB)
9191 end
92- dα = @thunk begin
92+ dα = @thunk let
9393 # TODO : this result should be AB = (C′ - βC) / α as C′ = βC + αAB
9494 AB = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
9595 return projectα (inner (AB, ΔC))
@@ -119,7 +119,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
119119 function pullback (ΔC′)
120120 ΔC = unthunk (ΔC′)
121121 dC = @thunk projectC (scale (ΔC, conj (β)))
122- dA = @thunk begin
122+ dA = @thunk let
123123 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
124124 E = one! (TO. tensoralloc_add (scalartype (A), A, q, conjA))
125125 twist! (E, filter (x -> ! isdual (space (E, x)), codomainind (E)))
@@ -130,7 +130,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
130130 conjA ? α : conj (α), Zero (), ba... )
131131 return projectA (_dA)
132132 end
133- dα = @thunk begin
133+ dα = @thunk let
134134 # TODO : this result might be easier to compute as:
135135 # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
136136 At = tensortrace (A, p, q, conjA)
0 commit comments