# Computing the cotangents of the scalars `α` and `β` is wasted work when they cannot
# carry a derivative (integers, or the static `One`/`Zero` markers). `_needs_tangent`
# answers that question from the argument's *type* alone, so the pullbacks can bail
# out early in a type-stable way.
#
# Passing a value forwards to its type; passing a type is answered directly.
_needs_tangent(x) = _needs_tangent(typeof(x))
# Conservative fallback: unknown types keep their tangent. It also terminates the
# value -> type forwarding above, which would otherwise recurse without end for any
# argument that is not a `Number`.
_needs_tangent(::Type) = true
_needs_tangent(::Type{<:Number}) = true
_needs_tangent(::Type{<:Integer}) = false
_needs_tangent(::Type{<:Union{One, Zero}}) = false
7+
18function ChainRulesCore. rrule(
29 :: typeof (TensorOperations. tensoradd!),
310 C:: AbstractTensorMap ,
@@ -13,32 +20,36 @@ function ChainRulesCore.rrule(
1320
# Pullback of the enclosing `tensoradd!` rule.
#
# Takes the (possibly thunked) output cotangent `ΔC′` and returns one cotangent per
# argument of the primal call, in order: the function itself, `C`, `A`, two
# non-differentiable arguments (presumably `pA` and `conjA` -- confirm against the
# rrule signature), `α`, `β`, and one `NoTangent()` for every entry of `ba`.
# Expensive cotangents are wrapped in `@thunk`, so they are only evaluated on demand.
function pullback(ΔC′)
    ΔC = unthunk(ΔC′)

    # With the static `Zero()` as `β` there is nothing to scale, so the cotangent of
    # `C` is a hard zero.
    dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))

    # Cotangent of `A`: add `ΔC` into a freshly allocated tensor using the inverse of
    # the permutation `pA`, scaled by `α` or `conj(α)` depending on `conjA`.
    dA = @thunk let
        ipA = invperm(linearize(pA))
        pdA = _repartition(ipA, A)
        TA = promote_add(ΔC, α)
        # TODO: allocator
        _dA = tensoralloc_add(TA, ΔC, pdA, conjA, Val(false))
        _dA = tensoradd!(_dA, ΔC, pdA, conjA, conjA ? α : conj(α), Zero(), ba...)
        projectA(_dA)
    end

    # Only build the thunk for `α` when its type can carry a tangent.
    dα = if _needs_tangent(α)
        @thunk let
            # TODO: this is an inner product implemented as a contraction
            # for non-symmetric tensors this might be more efficient like this,
            # but for symmetric tensors an intermediate object will anyways be created
            # and then it might be more efficient to use an addition and inner product
            tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
            _dα = tensorscalar(
                tensorcontract(
                    A, ((), linearize(pA)), !conjA,
                    tΔC, (trivtuple(TO.numind(pA)), ()), false,
                    ((), ()), One(), ba...
                )
            )
            projectα(_dα)
        end
    else
        ZeroTangent()
    end

    # Same early exit for `β`.
    dβ = _needs_tangent(β) ? @thunk(projectβ(inner(C, ΔC))) : ZeroTangent()

    # Every entry of `ba` is reported as non-differentiable.
    dba = map(_ -> NoTangent(), ba)
    return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
end
@@ -67,7 +78,7 @@ function ChainRulesCore.rrule(
6778 ipAB = invperm(linearize(pAB))
6879 pΔC = _repartition(ipAB, TO. numout(pA))
6980
70- dC = @thunk projectC(scale(ΔC, conj(β)))
81+ dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
7182 dA = @thunk let
7283 ipA = _repartition(invperm(linearize(pA)), A)
7384 conjΔC = conjA
@@ -91,7 +102,7 @@ function ChainRulesCore.rrule(
91102 ipA,
92103 conjA ? α : conj(α), Zero(), ba...
93104 )
94- return projectA(_dA)
105+ projectA(_dA)
95106 end
96107 dB = @thunk let
97108 ipB = _repartition(invperm(linearize(pB)), B)
@@ -116,14 +127,18 @@ function ChainRulesCore.rrule(
116127 ipB,
117128 conjB ? α : conj(α), Zero(), ba...
118129 )
119- return projectB(_dB)
130+ projectB(_dB)
120131 end
121- dα = @thunk let
122- # TODO : this result should be AB = (C′ - βC) / α as C′ = βC + αAB
123- AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba... )
124- return projectα(inner(AB, ΔC))
132+ dα = if _needs_tangent(α)
133+ @thunk let
134+ # TODO : this result should be AB = (C′ - βC) / α as C′ = βC + αAB
135+ AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba... )
136+ projectα(inner(AB, ΔC))
137+ end
138+ else
139+ ZeroTangent()
125140 end
126- dβ = @thunk projectβ(inner(C, ΔC))
141+ dβ = _needs_tangent(β) ? @thunk( projectβ(inner(C, ΔC))) : ZeroTangent( )
127142 dba = map(_ -> NoTangent(), ba)
128143 return NoTangent(), dC,
129144 dA, NoTangent(), NoTangent(),
@@ -149,7 +164,7 @@ function ChainRulesCore.rrule(
149164
150165 function pullback(ΔC′)
151166 ΔC = unthunk(ΔC′)
152- dC = @thunk projectC(scale(ΔC, conj(β)))
167+ dC = β === Zero() ? ZeroTangent() : @thunk projectC(scale(ΔC, conj(β)))
153168 dA = @thunk let
154169 ip = invperm((linearize(p). .. , q[1 ]. .. , q[2 ]. .. ))
155170 pdA = _repartition(ip, A)
@@ -163,15 +178,19 @@ function ChainRulesCore.rrule(
163178 _dA = tensorproduct!(
164179 _dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba...
165180 )
166- return projectA(_dA)
181+ projectA(_dA)
167182 end
168- dα = @thunk let
169- # TODO : this result might be easier to compute as:
170- # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
171- At = tensortrace(A, p, q, conjA)
172- return projectα(inner(At, ΔC))
183+ dα = if _needs_tangent(α)
184+ @thunk let
185+ # TODO : this result might be easier to compute as:
186+ # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
187+ At = tensortrace(A, p, q, conjA)
188+ projectα(inner(At, ΔC))
189+ end
190+ else
191+ ZeroTangent()
173192 end
174- dβ = @thunk projectβ(inner(C, ΔC))
193+ dβ = _needs_tangent(β) ? @thunk( projectβ(inner(C, ΔC))) : ZeroTangent( )
175194 dba = map(_ -> NoTangent(), ba)
176195 return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
177196 end
0 commit comments