Skip to content

Commit fab0dc8

Browse files
committed
TensorOperations chainrules optimizations for alpha and beta cases
1 parent 8b70a79 commit fab0dc8

File tree

1 file changed

+51
-32
lines changed

1 file changed

+51
-32
lines changed

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
# To avoid computing rrules for α and β when these aren't needed, we want to have a
2+
# type-stable quick bail-out
3+
_needs_tangent(x) = _needs_tangent(typeof(x))
4+
_needs_tangent(::Type{<:Number}) = true
5+
_needs_tangent(::Type{<:Integer}) = false
6+
_needs_tangent(::Type{<:Union{One, Zero}}) = false
7+
18
function ChainRulesCore.rrule(
29
::typeof(TensorOperations.tensoradd!),
310
C::AbstractTensorMap,
@@ -13,32 +20,36 @@ function ChainRulesCore.rrule(
1320

1421
function pullback(ΔC′)
1522
ΔC = unthunk(ΔC′)
16-
dC = @thunk projectC(scale(ΔC, conj(β)))
23+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
1724
dA = @thunk let
1825
ipA = invperm(linearize(pA))
1926
pdA = _repartition(ipA, A)
2027
TA = promote_add(ΔC, α)
2128
# TODO: allocator
2229
_dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
2330
_dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...)
24-
return projectA(_dA)
31+
projectA(_dA)
2532
end
26-
= @thunk let
27-
# TODO: this is an inner product implemented as a contraction
28-
# for non-symmetric tensors this might be more efficient like this,
29-
# but for symmetric tensors an intermediate object will anyways be created
30-
# and then it might be more efficient to use an addition and inner product
31-
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
32-
_dα = tensorscalar(
33-
tensorcontract(
34-
A, ((), linearize(pA)), !conjA,
35-
tΔC, (trivtuple(TO.numind(pA)), ()), false,
36-
((), ()), One(), ba...
33+
= if _needs_tangent(α)
34+
@thunk let
35+
# TODO: this is an inner product implemented as a contraction
36+
# for non-symmetric tensors this might be more efficient like this,
37+
# but for symmetric tensors an intermediate object will anyways be created
38+
# and then it might be more efficient to use an addition and inner product
39+
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
40+
_dα = tensorscalar(
41+
tensorcontract(
42+
A, ((), linearize(pA)), !conjA,
43+
tΔC, (trivtuple(TO.numind(pA)), ()), false,
44+
((), ()), One(), ba...
45+
)
3746
)
38-
)
39-
return projectα(_dα)
47+
projectα(_dα)
48+
end
49+
else
50+
ZeroTangent()
4051
end
41-
= @thunk projectβ(inner(C, ΔC))
52+
= _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
4253
dba = map(_ -> NoTangent(), ba)
4354
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
4455
end
@@ -67,7 +78,7 @@ function ChainRulesCore.rrule(
6778
ipAB = invperm(linearize(pAB))
6879
pΔC = _repartition(ipAB, TO.numout(pA))
6980

70-
dC = @thunk projectC(scale(ΔC, conj(β)))
81+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
7182
dA = @thunk let
7283
ipA = _repartition(invperm(linearize(pA)), A)
7384
conjΔC = conjA
@@ -91,7 +102,7 @@ function ChainRulesCore.rrule(
91102
ipA,
92103
conjA ? α : conj(α), Zero(), ba...
93104
)
94-
return projectA(_dA)
105+
projectA(_dA)
95106
end
96107
dB = @thunk let
97108
ipB = _repartition(invperm(linearize(pB)), B)
@@ -116,14 +127,18 @@ function ChainRulesCore.rrule(
116127
ipB,
117128
conjB ? α : conj(α), Zero(), ba...
118129
)
119-
return projectB(_dB)
130+
projectB(_dB)
120131
end
121-
= @thunk let
122-
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
123-
AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
124-
return projectα(inner(AB, ΔC))
132+
= if _needs_tangent(α)
133+
@thunk let
134+
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
135+
AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
136+
projectα(inner(AB, ΔC))
137+
end
138+
else
139+
ZeroTangent()
125140
end
126-
= @thunk projectβ(inner(C, ΔC))
141+
= _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
127142
dba = map(_ -> NoTangent(), ba)
128143
return NoTangent(), dC,
129144
dA, NoTangent(), NoTangent(),
@@ -149,7 +164,7 @@ function ChainRulesCore.rrule(
149164

150165
function pullback(ΔC′)
151166
ΔC = unthunk(ΔC′)
152-
dC = @thunk projectC(scale(ΔC, conj(β)))
167+
dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
153168
dA = @thunk let
154169
ip = invperm((linearize(p)..., q[1]..., q[2]...))
155170
pdA = _repartition(ip, A)
@@ -163,15 +178,19 @@ function ChainRulesCore.rrule(
163178
_dA = tensorproduct!(
164179
_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba...
165180
)
166-
return projectA(_dA)
181+
projectA(_dA)
167182
end
168-
= @thunk let
169-
# TODO: this result might be easier to compute as:
170-
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
171-
At = tensortrace(A, p, q, conjA)
172-
return projectα(inner(At, ΔC))
183+
= if _needs_tangent(α)
184+
@thunk let
185+
# TODO: this result might be easier to compute as:
186+
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
187+
At = tensortrace(A, p, q, conjA)
188+
projectα(inner(At, ΔC))
189+
end
190+
else
191+
ZeroTangent()
173192
end
174-
= @thunk projectβ(inner(C, ΔC))
193+
= _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()
175194
dba = map(_ -> NoTangent(), ba)
176195
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
177196
end

0 commit comments

Comments
 (0)