Skip to content

Commit f646b83

Browse files
authored
Some AD cleanup (#238)
* implement ZeroTangent for Zero beta * fix cuTensor compat and bump minor version Co-authored-by: Lukas Devos <lukas.devos@ugent.be>"
1 parent f1fd025 commit f646b83

File tree

2 files changed

+17
-75
lines changed

2 files changed

+17
-75
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
3-
version = "5.3.2"
3+
version = "5.4"
44
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
55

66
[deps]
@@ -49,7 +49,7 @@ StridedViews = "0.3, 0.4"
4949
Test = "1"
5050
TupleTools = "1.6"
5151
VectorInterface = "0.4.1,0.5"
52-
cuTENSOR = ">=2.1.1"
52+
cuTENSOR = "2.1.1"
5353
julia = "1.8"
5454

5555
[extras]

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 15 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,6 @@ _needs_tangent(::Type{<:Union{One, Zero}}) = false
6464

6565
# The current `rrule` design makes sure that the implementation for custom types does
6666
# not need to support the backend or allocator arguments
67-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
68-
# C,
69-
# A, pA::Index2Tuple, conjA::Bool,
70-
# α::Number, β::Number,
71-
# backend, allocator)
72-
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
73-
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
74-
# end
75-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
76-
# C,
77-
# A, pA::Index2Tuple, conjA::Bool,
78-
# α::Number, β::Number,
79-
# backend)
80-
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
81-
# return val, ΔC -> (pb(ΔC)..., NoTangent())
82-
# end
83-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
84-
# C,
85-
# A, pA::Index2Tuple, conjA::Bool,
86-
# α::Number, β::Number)
87-
# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
88-
# end
8967
function ChainRulesCore.rrule(
9068
::typeof(TensorOperations.tensoradd!),
9169
C,
@@ -105,7 +83,11 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
10583

10684
function pullback(ΔC′)
10785
ΔC = unthunk(ΔC′)
108-
dC = @thunk projectC(scale(ΔC, conj(β)))
86+
dC = if β === Zero()
87+
ZeroTangent()
88+
else
89+
@thunk projectC(scale(ΔC, conj(β)))
90+
end
10991
dA = @thunk let
11092
ipA = invperm(linearize(pA))
11193
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
@@ -148,35 +130,6 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
148130
return C′, pullback
149131
end
150132

151-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
152-
# C,
153-
# A, pA::Index2Tuple, conjA::Bool,
154-
# B, pB::Index2Tuple, conjB::Bool,
155-
# pAB::Index2Tuple,
156-
# α::Number, β::Number,
157-
# backend, allocator)
158-
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
159-
# (backend, allocator))
160-
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
161-
# end
162-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
163-
# C,
164-
# A, pA::Index2Tuple, conjA::Bool,
165-
# B, pB::Index2Tuple, conjB::Bool,
166-
# pAB::Index2Tuple,
167-
# α::Number, β::Number,
168-
# backend)
169-
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
170-
# return val, ΔC -> (pb(ΔC)..., NoTangent())
171-
# end
172-
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
173-
# C,
174-
# A, pA::Index2Tuple, conjA::Bool,
175-
# B, pB::Index2Tuple, conjB::Bool,
176-
# pAB::Index2Tuple,
177-
# α::Number, β::Number)
178-
# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
179-
# end
180133
function ChainRulesCore.rrule(
181134
::typeof(TensorOperations.tensorcontract!),
182135
C,
@@ -204,7 +157,11 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
204157
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
205158
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
206159
)
207-
dC = @thunk projectC(scale(ΔC, conj(β)))
160+
dC = if β === Zero()
161+
ZeroTangent()
162+
else
163+
@thunk projectC(scale(ΔC, conj(β)))
164+
end
208165
dA = @thunk let
209166
ipA = (invperm(linearize(pA)), ())
210167
conjΔC = conjA
@@ -273,25 +230,6 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
273230
return C′, pullback
274231
end
275232

276-
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
277-
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
278-
# α::Number, β::Number,
279-
# backend, allocator)
280-
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
281-
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
282-
# end
283-
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
284-
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
285-
# α::Number, β::Number,
286-
# backend)
287-
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
288-
# return val, ΔC -> (pb(ΔC)..., NoTangent())
289-
# end
290-
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
291-
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
292-
# α::Number, β::Number)
293-
# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
294-
# end
295233
function ChainRulesCore.rrule(
296234
::typeof(tensortrace!), C,
297235
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
@@ -310,7 +248,11 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
310248

311249
function pullback(ΔC′)
312250
ΔC = unthunk(ΔC′)
313-
dC = @thunk projectC(scale(ΔC, conj(β)))
251+
dC = if β === Zero()
252+
ZeroTangent()
253+
else
254+
@thunk projectC(scale(ΔC, conj(β)))
255+
end
314256
dA = @thunk let
315257
ip = invperm((linearize(p)..., q[1]..., q[2]...))
316258
Es = map(q[1], q[2]) do i1, i2

0 commit comments

Comments
 (0)