|
46 | 46 |
|
47 | 47 | # The current `rrule` design makes sure that the implementation for custom types does |
48 | 48 | # not need to support the backend or allocator arguments |
| 49 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
| 50 | +# C, |
| 51 | +# A, pA::Index2Tuple, conjA::Bool, |
| 52 | +# α::Number, β::Number, |
| 53 | +# backend, allocator) |
| 54 | +# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator)) |
| 55 | +# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
| 56 | +# end |
| 57 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
| 58 | +# C, |
| 59 | +# A, pA::Index2Tuple, conjA::Bool, |
| 60 | +# α::Number, β::Number, |
| 61 | +# backend) |
| 62 | +# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,)) |
| 63 | +# return val, ΔC -> (pb(ΔC)..., NoTangent()) |
| 64 | +# end |
| 65 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
| 66 | +# C, |
| 67 | +# A, pA::Index2Tuple, conjA::Bool, |
| 68 | +# α::Number, β::Number) |
| 69 | +# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ()) |
| 70 | +# end |
49 | 71 | function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
50 | 72 | C, |
51 | 73 | A, pA::Index2Tuple, conjA::Bool, |
52 | 74 | α::Number, β::Number, |
53 | | - backend, allocator) |
54 | | - val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator)) |
55 | | - return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
56 | | -end |
57 | | -function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
58 | | - C, |
59 | | - A, pA::Index2Tuple, conjA::Bool, |
60 | | - α::Number, β::Number, |
61 | | - backend) |
62 | | - val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,)) |
63 | | - return val, ΔC -> (pb(ΔC)..., NoTangent()) |
64 | | -end |
65 | | -function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), |
66 | | - C, |
67 | | - A, pA::Index2Tuple, conjA::Bool, |
68 | | - α::Number, β::Number) |
69 | | - return _rrule_tensoradd!(C, A, pA, conjA, α, β, ()) |
| 75 | + ba...) |
| 76 | + return _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) |
70 | 77 | end |
71 | 78 | function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) |
72 | 79 | C′ = tensoradd!(copy(C), A, pA, conjA, α, β, ba...) |
@@ -98,40 +105,50 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) |
98 | 105 | ((), ()), One(), ba...)) |
99 | 106 | return projectβ(_dβ) |
100 | 107 | end |
101 | | - return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ |
| 108 | + dba = map(_ -> NoTangent(), ba) |
| 109 | + return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... |
102 | 110 | end |
103 | 111 |
|
104 | 112 | return C′, pullback |
105 | 113 | end |
106 | 114 |
|
| 115 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
| 116 | +# C, |
| 117 | +# A, pA::Index2Tuple, conjA::Bool, |
| 118 | +# B, pB::Index2Tuple, conjB::Bool, |
| 119 | +# pAB::Index2Tuple, |
| 120 | +# α::Number, β::Number, |
| 121 | +# backend, allocator) |
| 122 | +# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, |
| 123 | +# (backend, allocator)) |
| 124 | +# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
| 125 | +# end |
| 126 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
| 127 | +# C, |
| 128 | +# A, pA::Index2Tuple, conjA::Bool, |
| 129 | +# B, pB::Index2Tuple, conjB::Bool, |
| 130 | +# pAB::Index2Tuple, |
| 131 | +# α::Number, β::Number, |
| 132 | +# backend) |
| 133 | +# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,)) |
| 134 | +# return val, ΔC -> (pb(ΔC)..., NoTangent()) |
| 135 | +# end |
| 136 | +# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
| 137 | +# C, |
| 138 | +# A, pA::Index2Tuple, conjA::Bool, |
| 139 | +# B, pB::Index2Tuple, conjB::Bool, |
| 140 | +# pAB::Index2Tuple, |
| 141 | +# α::Number, β::Number) |
| 142 | +# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ()) |
| 143 | +# end |
107 | 144 | function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
108 | 145 | C, |
109 | 146 | A, pA::Index2Tuple, conjA::Bool, |
110 | 147 | B, pB::Index2Tuple, conjB::Bool, |
111 | 148 | pAB::Index2Tuple, |
112 | 149 | α::Number, β::Number, |
113 | | - backend, allocator) |
114 | | - val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, |
115 | | - (backend, allocator)) |
116 | | - return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
117 | | -end |
118 | | -function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
119 | | - C, |
120 | | - A, pA::Index2Tuple, conjA::Bool, |
121 | | - B, pB::Index2Tuple, conjB::Bool, |
122 | | - pAB::Index2Tuple, |
123 | | - α::Number, β::Number, |
124 | | - backend) |
125 | | - val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,)) |
126 | | - return val, ΔC -> (pb(ΔC)..., NoTangent()) |
127 | | -end |
128 | | -function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), |
129 | | - C, |
130 | | - A, pA::Index2Tuple, conjA::Bool, |
131 | | - B, pB::Index2Tuple, conjB::Bool, |
132 | | - pAB::Index2Tuple, |
133 | | - α::Number, β::Number) |
134 | | - return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ()) |
| 150 | + ba...) |
| 151 | + return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) |
135 | 152 | end |
136 | 153 | function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) |
137 | 154 | C′ = tensorcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) |
@@ -187,32 +204,39 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) |
187 | 204 | ((), ()), One(), ba...)) |
188 | 205 | return projectβ(_dβ) |
189 | 206 | end |
| 207 | + dba = map(_ -> NoTangent(), ba) |
190 | 208 | return NoTangent(), dC, |
191 | 209 | dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), |
192 | | - NoTangent(), dα, dβ |
| 210 | + NoTangent(), dα, dβ, dba... |
193 | 211 | end |
194 | 212 |
|
195 | 213 | return C′, pullback |
196 | 214 | end |
197 | 215 |
|
| 216 | +# function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
| 217 | +# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
| 218 | +# α::Number, β::Number, |
| 219 | +# backend, allocator) |
| 220 | +# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator)) |
| 221 | +# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
| 222 | +# end |
| 223 | +# function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
| 224 | +# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
| 225 | +# α::Number, β::Number, |
| 226 | +# backend) |
| 227 | +# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,)) |
| 228 | +# return val, ΔC -> (pb(ΔC)..., NoTangent()) |
| 229 | +# end |
| 230 | +# function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
| 231 | +# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
| 232 | +# α::Number, β::Number) |
| 233 | +# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ()) |
| 234 | +# end |
198 | 235 | function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
199 | 236 | A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
200 | 237 | α::Number, β::Number, |
201 | | - backend, allocator) |
202 | | - val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator)) |
203 | | - return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) |
204 | | -end |
205 | | -function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
206 | | - A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
207 | | - α::Number, β::Number, |
208 | | - backend) |
209 | | - val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,)) |
210 | | - return val, ΔC -> (pb(ΔC)..., NoTangent()) |
211 | | -end |
212 | | -function ChainRulesCore.rrule(::typeof(tensortrace!), C, |
213 | | - A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, |
214 | | - α::Number, β::Number) |
215 | | - return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ()) |
| 238 | + ba...) |
| 239 | + return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) |
216 | 240 | end |
217 | 241 | function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) |
218 | 242 | C′ = tensortrace!(copy(C), A, p, q, conjA, α, β, ba...) |
@@ -253,7 +277,8 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) |
253 | 277 | ((), ()), One(), ba...)) |
254 | 278 | return projectβ(_dβ) |
255 | 279 | end |
256 | | - return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ |
| 280 | + dba = map(_ -> NoTangent(), ba) |
| 281 | + return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba... |
257 | 282 | end |
258 | 283 |
|
259 | 284 | return C′, pullback |
|
0 commit comments