1+ # To avoid computing the tangents dα and dβ in the pullbacks when these aren't needed, we
2+ # want a type-stable quick bail-out: the answer is decided from the scalar's type alone.
3+ _needs_tangent (x) = _needs_tangent (typeof (x)) # value form: forwards to the type-based methods below
4+ _needs_tangent (:: Type{<:Number} ) = true # default for numeric scalars: the pullback builds a thunked tangent
5+ _needs_tangent (:: Type{<:Integer} ) = false # integer scalars: the pullback returns ZeroTangent() instead
6+ _needs_tangent (:: Type{<:Union{One, Zero}} ) = false # One/Zero: pullback returns ZeroTangent(); NOTE(review): presumably constant singleton scalars — confirm
7+
18function ChainRulesCore. rrule (
29 :: typeof (TensorOperations. tensoradd!),
310 C:: AbstractTensorMap ,
@@ -13,32 +20,36 @@ function ChainRulesCore.rrule(
1320
1421 function pullback (ΔC′)
1522 ΔC = unthunk (ΔC′)
16- dC = @thunk projectC (scale (ΔC, conj (β)))
23+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (scale (ΔC, conj (β)))
1724 dA = @thunk let
1825 ipA = invperm (linearize (pA))
1926 pdA = _repartition (ipA, A)
2027 TA = promote_add (ΔC, α)
2128 # TODO : allocator
2229 _dA = tensoralloc_add (TA, ΔC, pdA, conjA, Val (false ))
2330 _dA = tensoradd! (_dA, ΔC, pdA, conjA, conjA ? α : conj (α), Zero (), ba... )
24- return projectA (_dA)
31+ projectA (_dA)
2532 end
26- dα = @thunk let
27- # TODO : this is an inner product implemented as a contraction
28- # for non-symmetric tensors this might be more efficient like this,
29- # but for symmetric tensors an intermediate object will anyways be created
30- # and then it might be more efficient to use an addition and inner product
31- tΔC = _twist_nocopy (ΔC, filter (x -> isdual (space (ΔC, x)), allind (ΔC)))
32- _dα = tensorscalar (
33- tensorcontract (
34- A, ((), linearize (pA)), ! conjA,
35- tΔC, (trivtuple (TO. numind (pA)), ()), false ,
36- ((), ()), One (), ba...
33+ dα = if _needs_tangent (α)
34+ @thunk let
35+ # TODO : this is an inner product implemented as a contraction
36+ # for non-symmetric tensors this might be more efficient like this,
37+ # but for symmetric tensors an intermediate object will anyways be created
38+ # and then it might be more efficient to use an addition and inner product
39+ tΔC = twist (ΔC, filter (x -> isdual (space (ΔC, x)), allind (ΔC)); copy = false )
40+ _dα = tensorscalar (
41+ tensorcontract (
42+ A, ((), linearize (pA)), ! conjA,
43+ tΔC, (trivtuple (TO. numind (pA)), ()), false ,
44+ ((), ()), One (), ba...
45+ )
3746 )
38- )
39- return projectα (_dα)
47+ projectα (_dα)
48+ end
49+ else
50+ ZeroTangent ()
4051 end
41- dβ = @thunk projectβ (inner (C, ΔC))
52+ dβ = _needs_tangent (β) ? @thunk ( projectβ (inner (C, ΔC))) : ZeroTangent ( )
4253 dba = map (_ -> NoTangent (), ba)
4354 return NoTangent (), dC, dA, NoTangent (), NoTangent (), dα, dβ, dba...
4455 end
@@ -67,19 +78,19 @@ function ChainRulesCore.rrule(
6778 ipAB = invperm (linearize (pAB))
6879 pΔC = _repartition (ipAB, TO. numout (pA))
6980
70- dC = @thunk projectC (scale (ΔC, conj (β)))
81+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (scale (ΔC, conj (β)))
7182 dA = @thunk let
7283 ipA = _repartition (invperm (linearize (pA)), A)
7384 conjΔC = conjA
7485 conjB′ = conjA ? conjB : ! conjB
7586 TA = promote_contract (scalartype (ΔC), scalartype (B), scalartype (α))
7687 # TODO : allocator
77- tB = _twist_nocopy (
88+ tB = twist (
7889 B,
7990 TupleTools. vcat (
8091 filter (x -> ! isdual (space (B, x)), pB[1 ]),
8192 filter (x -> isdual (space (B, x)), pB[2 ])
82- )
93+ ); copy = false
8394 )
8495 _dA = tensoralloc_contract (
8596 TA, ΔC, pΔC, conjΔC, tB, reverse (pB), conjB′, ipA, Val (false )
@@ -91,20 +102,20 @@ function ChainRulesCore.rrule(
91102 ipA,
92103 conjA ? α : conj (α), Zero (), ba...
93104 )
94- return projectA (_dA)
105+ projectA (_dA)
95106 end
96107 dB = @thunk let
97108 ipB = _repartition (invperm (linearize (pB)), B)
98109 conjΔC = conjB
99110 conjA′ = conjB ? conjA : ! conjA
100111 TB = promote_contract (scalartype (ΔC), scalartype (A), scalartype (α))
101112 # TODO : allocator
102- tA = _twist_nocopy (
113+ tA = twist (
103114 A,
104115 TupleTools. vcat (
105116 filter (x -> isdual (space (A, x)), pA[1 ]),
106117 filter (x -> ! isdual (space (A, x)), pA[2 ])
107- )
118+ ); copy = false
108119 )
109120 _dB = tensoralloc_contract (
110121 TB, tA, reverse (pA), conjA′, ΔC, pΔC, conjΔC, ipB, Val (false )
@@ -116,14 +127,18 @@ function ChainRulesCore.rrule(
116127 ipB,
117128 conjB ? α : conj (α), Zero (), ba...
118129 )
119- return projectB (_dB)
130+ projectB (_dB)
120131 end
121- dα = @thunk let
122- # TODO : this result should be AB = (C′ - βC) / α as C′ = βC + αAB
123- AB = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
124- return projectα (inner (AB, ΔC))
132+ dα = if _needs_tangent (α)
133+ @thunk let
134+ # TODO : this result should be AB = (C′ - βC) / α as C′ = βC + αAB
135+ AB = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
136+ projectα (inner (AB, ΔC))
137+ end
138+ else
139+ ZeroTangent ()
125140 end
126- dβ = @thunk projectβ (inner (C, ΔC))
141+ dβ = _needs_tangent (β) ? @thunk ( projectβ (inner (C, ΔC))) : ZeroTangent ( )
127142 dba = map (_ -> NoTangent (), ba)
128143 return NoTangent (), dC,
129144 dA, NoTangent (), NoTangent (),
@@ -149,7 +164,7 @@ function ChainRulesCore.rrule(
149164
150165 function pullback (ΔC′)
151166 ΔC = unthunk (ΔC′)
152- dC = @thunk projectC (scale (ΔC, conj (β)))
167+ dC = β === Zero () ? ZeroTangent () : @thunk projectC (scale (ΔC, conj (β)))
153168 dA = @thunk let
154169 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
155170 pdA = _repartition (ip, A)
@@ -163,15 +178,19 @@ function ChainRulesCore.rrule(
163178 _dA = tensorproduct! (
164179 _dA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj (α), Zero (), ba...
165180 )
166- return projectA (_dA)
181+ projectA (_dA)
167182 end
168- dα = @thunk let
169- # TODO : this result might be easier to compute as:
170- # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
171- At = tensortrace (A, p, q, conjA)
172- return projectα (inner (At, ΔC))
183+ dα = if _needs_tangent (α)
184+ @thunk let
185+ # TODO : this result might be easier to compute as:
186+ # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
187+ At = tensortrace (A, p, q, conjA)
188+ projectα (inner (At, ΔC))
189+ end
190+ else
191+ ZeroTangent ()
173192 end
174- dβ = @thunk projectβ (inner (C, ΔC))
193+ dβ = _needs_tangent (β) ? @thunk ( projectβ (inner (C, ΔC))) : ZeroTangent ( )
175194 dba = map (_ -> NoTangent (), ba)
176195 return NoTangent (), dC, dA, NoTangent (), NoTangent (), NoTangent (), dα, dβ, dba...
177196 end
@@ -188,10 +207,3 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
188207 end
189208 return val, scalar_pullback
190209end
191-
192- # temporary function to avoid copies when not needed
193- # TODO : remove once `twist(t; copy=false)` is defined
194- function _twist_nocopy (t, inds; kwargs... )
195- (BraidingStyle (sectortype (t)) isa Bosonic || isempty (inds)) && return t
196- return twist (t, inds; kwargs... )
197- end
0 commit comments