Skip to content

Commit dfe4026

Browse files
authored
TensorOperations AD clean-up (#343)
* TensorOperations chainrules optimizations for alpha and beta cases
* remove temporary `_twist_nocopy`
1 parent 8b70a79 commit dfe4026

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
4242
ipA = (codomainind(A), domainind(A))
4343
pB = (allind(B), ())
4444
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
45-
tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B)))
45+
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)); copy = false)
4646
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
4747
return projectA(dA)
4848
end
4949
dB_ = @thunk let
5050
ipB = (codomainind(B), domainind(B))
5151
pA = ((), allind(A))
5252
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
53-
tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A)))
53+
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)); copy = false)
5454
dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB)
5555
return projectB(dB)
5656
end

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 55 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# To avoid computing rrules for α and β when these aren't needed, we want to have a
2+
# type-stable quick bail-out
3+
_needs_tangent(x) = _needs_tangent(typeof(x))
4+
_needs_tangent(::Type{<:Number}) = true
5+
_needs_tangent(::Type{<:Integer}) = false
6+
_needs_tangent(::Type{<:Union{One, Zero}}) = false
7+
18
function ChainRulesCore.rrule(
29
::typeof(TensorOperations.tensoradd!),
310
C::AbstractTensorMap,
@@ -13,32 +20,36 @@ function ChainRulesCore.rrule(
1320

1421
function pullback(ΔC′)
1522
ΔC = unthunk(ΔC′)
16-
dC = @thunk projectC(scale(ΔC, conj(β)))
23+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
1724
dA = @thunk let
1825
ipA = invperm(linearize(pA))
1926
pdA = _repartition(ipA, A)
2027
TA = promote_add(ΔC, α)
2128
# TODO: allocator
2229
_dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
2330
_dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...)
24-
return projectA(_dA)
31+
projectA(_dA)
2532
end
26-
dα = @thunk let
27-
# TODO: this is an inner product implemented as a contraction
28-
# for non-symmetric tensors this might be more efficient like this,
29-
# but for symmetric tensors an intermediate object will anyways be created
30-
# and then it might be more efficient to use an addition and inner product
31-
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
32-
_dα = tensorscalar(
33-
tensorcontract(
34-
A, ((), linearize(pA)), !conjA,
35-
tΔC, (trivtuple(TO.numind(pA)), ()), false,
36-
((), ()), One(), ba...
33+
dα = if _needs_tangent(α)
34+
@thunk let
35+
# TODO: this is an inner product implemented as a contraction
36+
# for non-symmetric tensors this might be more efficient like this,
37+
# but for symmetric tensors an intermediate object will anyways be created
38+
# and then it might be more efficient to use an addition and inner product
39+
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false)
40+
_dα = tensorscalar(
41+
tensorcontract(
42+
A, ((), linearize(pA)), !conjA,
43+
tΔC, (trivtuple(TO.numind(pA)), ()), false,
44+
((), ()), One(), ba...
45+
)
3746
)
38-
)
39-
return projectα(_dα)
47+
projectα(_dα)
48+
end
49+
else
50+
ZeroTangent()
4051
end
41-
dβ = @thunk projectβ(inner(C, ΔC))
52+
dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
4253
dba = map(_ -> NoTangent(), ba)
4354
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
4455
end
@@ -67,19 +78,19 @@ function ChainRulesCore.rrule(
6778
ipAB = invperm(linearize(pAB))
6879
pΔC = _repartition(ipAB, TO.numout(pA))
6980

70-
dC = @thunk projectC(scale(ΔC, conj(β)))
81+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
7182
dA = @thunk let
7283
ipA = _repartition(invperm(linearize(pA)), A)
7384
conjΔC = conjA
7485
conjB′ = conjA ? conjB : !conjB
7586
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
7687
# TODO: allocator
77-
tB = _twist_nocopy(
88+
tB = twist(
7889
B,
7990
TupleTools.vcat(
8091
filter(x -> !isdual(space(B, x)), pB[1]),
8192
filter(x -> isdual(space(B, x)), pB[2])
82-
)
93+
); copy = false
8394
)
8495
_dA = tensoralloc_contract(
8596
TA, ΔC, pΔC, conjΔC, tB, reverse(pB), conjB′, ipA, Val(false)
@@ -91,20 +102,20 @@ function ChainRulesCore.rrule(
91102
ipA,
92103
conjA ? α : conj(α), Zero(), ba...
93104
)
94-
return projectA(_dA)
105+
projectA(_dA)
95106
end
96107
dB = @thunk let
97108
ipB = _repartition(invperm(linearize(pB)), B)
98109
conjΔC = conjB
99110
conjA′ = conjB ? conjA : !conjA
100111
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
101112
# TODO: allocator
102-
tA = _twist_nocopy(
113+
tA = twist(
103114
A,
104115
TupleTools.vcat(
105116
filter(x -> isdual(space(A, x)), pA[1]),
106117
filter(x -> !isdual(space(A, x)), pA[2])
107-
)
118+
); copy = false
108119
)
109120
_dB = tensoralloc_contract(
110121
TB, tA, reverse(pA), conjA′, ΔC, pΔC, conjΔC, ipB, Val(false)
@@ -116,14 +127,18 @@ function ChainRulesCore.rrule(
116127
ipB,
117128
conjB ? α : conj(α), Zero(), ba...
118129
)
119-
return projectB(_dB)
130+
projectB(_dB)
120131
end
121-
dα = @thunk let
122-
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
123-
AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
124-
return projectα(inner(AB, ΔC))
132+
dα = if _needs_tangent(α)
133+
@thunk let
134+
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
135+
AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
136+
projectα(inner(AB, ΔC))
137+
end
138+
else
139+
ZeroTangent()
125140
end
126-
dβ = @thunk projectβ(inner(C, ΔC))
141+
dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
127142
dba = map(_ -> NoTangent(), ba)
128143
return NoTangent(), dC,
129144
dA, NoTangent(), NoTangent(),
@@ -149,7 +164,7 @@ function ChainRulesCore.rrule(
149164

150165
function pullback(ΔC′)
151166
ΔC = unthunk(ΔC′)
152-
dC = @thunk projectC(scale(ΔC, conj(β)))
167+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
153168
dA = @thunk let
154169
ip = invperm((linearize(p)..., q[1]..., q[2]...))
155170
pdA = _repartition(ip, A)
@@ -163,15 +178,19 @@ function ChainRulesCore.rrule(
163178
_dA = tensorproduct!(
164179
_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba...
165180
)
166-
return projectA(_dA)
181+
projectA(_dA)
167182
end
168-
dα = @thunk let
169-
# TODO: this result might be easier to compute as:
170-
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
171-
At = tensortrace(A, p, q, conjA)
172-
return projectα(inner(At, ΔC))
183+
dα = if _needs_tangent(α)
184+
@thunk let
185+
# TODO: this result might be easier to compute as:
186+
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
187+
At = tensortrace(A, p, q, conjA)
188+
projectα(inner(At, ΔC))
189+
end
190+
else
191+
ZeroTangent()
173192
end
174-
dβ = @thunk projectβ(inner(C, ΔC))
193+
dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
175194
dba = map(_ -> NoTangent(), ba)
176195
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
177196
end
@@ -188,10 +207,3 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
188207
end
189208
return val, scalar_pullback
190209
end
191-
192-
# temporary function to avoid copies when not needed
193-
# TODO: remove once `twist(t; copy=false)` is defined
194-
function _twist_nocopy(t, inds; kwargs...)
195-
(BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t
196-
return twist(t, inds; kwargs...)
197-
end

0 commit comments

Comments
 (0)