Skip to content

Commit 80ade51

Browse files
authored
fix ad type stability (#198)
1 parent 582b6d7 commit 80ade51

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
3737
pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...),
3838
((codomainind(B) .+ numout(A))...,
3939
(domainind(B) .+ (numin(A) + numout(A)))...))
40-
dA_ = @thunk begin
40+
dA_ = @thunk let
4141
ipA = (codomainind(A), domainind(A))
4242
pB = (allind(B), ())
4343
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
4444
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
4545
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
4646
return projectA(dA)
4747
end
48-
dB_ = @thunk begin
48+
dB_ = @thunk let
4949
ipB = (codomainind(B), domainind(B))
5050
pA = ((), allind(A))
5151
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
1212
function pullback(ΔC′)
1313
ΔC = unthunk(ΔC′)
1414
dC = @thunk projectC(scale(ΔC, conj(β)))
15-
dA = @thunk begin
15+
dA = @thunk let
1616
ipA = invperm(linearize(pA))
1717
_dA = zerovector(A, promote_add(ΔC, α))
1818
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
1919
return projectA(_dA)
2020
end
21-
dα = @thunk begin
21+
dα = @thunk let
2222
# TODO: this is an inner product implemented as a contraction
2323
# for non-symmetric tensors this might be more efficient like this,
2424
# but for symmetric tensors an intermediate object will anyways be created
@@ -59,7 +59,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
5959
TupleTools.getindices(ipAB, TO.numout(pA) .+ trivtuple(TO.numin(pB))))
6060

6161
dC = @thunk projectC(scale(ΔC, conj(β)))
62-
dA = @thunk begin
62+
dA = @thunk let
6363
ipA = (invperm(linearize(pA)), ())
6464
conjΔC = conjA
6565
conjB′ = conjA ? conjB : !conjB
@@ -74,7 +74,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7474
conjA ? α : conj(α), Zero(), ba...)
7575
return projectA(_dA)
7676
end
77-
dB = @thunk begin
77+
dB = @thunk let
7878
ipB = (invperm(linearize(pB)), ())
7979
conjΔC = conjB
8080
conjA′ = conjB ? conjA : !conjA
@@ -89,7 +89,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
8989
conjB ? α : conj(α), Zero(), ba...)
9090
return projectB(_dB)
9191
end
92-
dα = @thunk begin
92+
dα = @thunk let
9393
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
9494
AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
9595
return projectα(inner(AB, ΔC))
@@ -119,7 +119,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
119119
function pullback(ΔC′)
120120
ΔC = unthunk(ΔC′)
121121
dC = @thunk projectC(scale(ΔC, conj(β)))
122-
dA = @thunk begin
122+
dA = @thunk let
123123
ip = invperm((linearize(p)..., q[1]..., q[2]...))
124124
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
125125
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
@@ -130,7 +130,7 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
130130
conjA ? α : conj(α), Zero(), ba...)
131131
return projectA(_dA)
132132
end
133-
dα = @thunk begin
133+
dα = @thunk let
134134
# TODO: this result might be easier to compute as:
135135
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
136136
At = tensortrace(A, p, q, conjA)

0 commit comments

Comments
 (0)