@@ -64,28 +64,6 @@ _needs_tangent(::Type{<:Union{One, Zero}}) = false
6464
6565# The current `rrule` design makes sure that the implementation for custom types does
6666# not need to support the backend or allocator arguments
67- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
68- # C,
69- # A, pA::Index2Tuple, conjA::Bool,
70- # α::Number, β::Number,
71- # backend, allocator)
72- # val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
73- # return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
74- # end
75- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
76- # C,
77- # A, pA::Index2Tuple, conjA::Bool,
78- # α::Number, β::Number,
79- # backend)
80- # val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
81- # return val, ΔC -> (pb(ΔC)..., NoTangent())
82- # end
83- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
84- # C,
85- # A, pA::Index2Tuple, conjA::Bool,
86- # α::Number, β::Number)
87- # return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
88- # end
8967function ChainRulesCore. rrule(
9068 :: typeof (TensorOperations. tensoradd!),
9169 C,
@@ -105,7 +83,11 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
10583
10684 function pullback(ΔC′)
10785 ΔC = unthunk(ΔC′)
108- dC = @thunk projectC(scale(ΔC, conj(β)))
86+ dC = if β === Zero()
87+ ZeroTangent()
88+ else
89+ @thunk projectC(scale(ΔC, conj(β)))
90+ end
10991 dA = @thunk let
11092 ipA = invperm(linearize(pA))
11193 _dA = zerovector(A, VectorInterface. promote_add(ΔC, α))
@@ -148,35 +130,6 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
148130 return C′, pullback
149131end
150132
151- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
152- # C,
153- # A, pA::Index2Tuple, conjA::Bool,
154- # B, pB::Index2Tuple, conjB::Bool,
155- # pAB::Index2Tuple,
156- # α::Number, β::Number,
157- # backend, allocator)
158- # val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
159- # (backend, allocator))
160- # return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
161- # end
162- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
163- # C,
164- # A, pA::Index2Tuple, conjA::Bool,
165- # B, pB::Index2Tuple, conjB::Bool,
166- # pAB::Index2Tuple,
167- # α::Number, β::Number,
168- # backend)
169- # val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
170- # return val, ΔC -> (pb(ΔC)..., NoTangent())
171- # end
172- # function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
173- # C,
174- # A, pA::Index2Tuple, conjA::Bool,
175- # B, pB::Index2Tuple, conjB::Bool,
176- # pAB::Index2Tuple,
177- # α::Number, β::Number)
178- # return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
179- # end
180133function ChainRulesCore. rrule(
181134 :: typeof (TensorOperations. tensorcontract!),
182135 C,
@@ -204,7 +157,11 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
204157 TupleTools. getindices(ipAB, trivtuple(numout(pA))),
205158 TupleTools. getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
206159 )
207- dC = @thunk projectC(scale(ΔC, conj(β)))
160+ dC = if β === Zero()
161+ ZeroTangent()
162+ else
163+ @thunk projectC(scale(ΔC, conj(β)))
164+ end
208165 dA = @thunk let
209166 ipA = (invperm(linearize(pA)), ())
210167 conjΔC = conjA
@@ -273,25 +230,6 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
273230 return C′, pullback
274231end
275232
276- # function ChainRulesCore.rrule(::typeof(tensortrace!), C,
277- # A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
278- # α::Number, β::Number,
279- # backend, allocator)
280- # val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
281- # return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
282- # end
283- # function ChainRulesCore.rrule(::typeof(tensortrace!), C,
284- # A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
285- # α::Number, β::Number,
286- # backend)
287- # val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
288- # return val, ΔC -> (pb(ΔC)..., NoTangent())
289- # end
290- # function ChainRulesCore.rrule(::typeof(tensortrace!), C,
291- # A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
292- # α::Number, β::Number)
293- # return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
294- # end
295233function ChainRulesCore. rrule(
296234 :: typeof (tensortrace!), C,
297235 A, p:: Index2Tuple , q:: Index2Tuple , conjA:: Bool ,
@@ -310,7 +248,11 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
310248
311249 function pullback(ΔC′)
312250 ΔC = unthunk(ΔC′)
313- dC = @thunk projectC(scale(ΔC, conj(β)))
251+ dC = if β === Zero()
252+ ZeroTangent()
253+ else
254+ @thunk projectC(scale(ΔC, conj(β)))
255+ end
314256 dA = @thunk let
315257 ip = invperm((linearize(p). .. , q[1 ]. .. , q[2 ]. .. ))
316258 Es = map(q[1 ], q[2 ]) do i1, i2
0 commit comments