@@ -21,13 +21,17 @@ trivtuple(N) = ntuple(identity, N)
2121
2222# Cannot free intermediate tensors when using AD
2323# Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
24- function ChainRulesCore. rrule (:: typeof (TensorOperations. tensorfree!),
25- allocator= DefaultAllocator ())
24+ function ChainRulesCore. rrule (
25+ :: typeof (TensorOperations. tensorfree!),
26+ allocator = DefaultAllocator ()
27+ )
2628 tensorfree!_pullback (Δargs... ) = (NoTangent (), NoTangent ())
2729 return nothing , tensorfree!_pullback
2830end
29- function ChainRulesCore. rrule (:: typeof (TensorOperations. tensoralloc), ttype, structure,
30- istemp, allocator= DefaultAllocator ())
31+ function ChainRulesCore. rrule (
32+ :: typeof (TensorOperations. tensoralloc), ttype, structure,
33+ istemp, allocator = DefaultAllocator ()
34+ )
3135 output = TensorOperations. tensoralloc (ttype, structure, Val (false ), allocator)
3236 function tensoralloc_pullback (Δargs... )
3337 return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ())
6872# α::Number, β::Number)
6973# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
7074# end
71- function ChainRulesCore. rrule (:: typeof (TensorOperations. tensoradd!),
72- C,
73- A, pA:: Index2Tuple , conjA:: Bool ,
74- α:: Number , β:: Number ,
75- ba... )
75+ function ChainRulesCore. rrule (
76+ :: typeof (TensorOperations. tensoradd!),
77+ C,
78+ A, pA:: Index2Tuple , conjA:: Bool ,
79+ α:: Number , β:: Number ,
80+ ba...
81+ )
7682 return _rrule_tensoradd! (C, A, pA, conjA, α, β, ba)
7783end
7884function _rrule_tensoradd! (C, A, pA, conjA, α, β, ba)
@@ -93,16 +99,24 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
9399 return projectA (_dA)
94100 end
95101 dα = @thunk let
96- _dα = tensorscalar (tensorcontract (A, ((), linearize (pA)), ! conjA,
97- ΔC, (trivtuple (numind (pA)), ()), false ,
98- ((), ()), One (), ba... ))
102+ _dα = tensorscalar (
103+ tensorcontract (
104+ A, ((), linearize (pA)), ! conjA,
105+ ΔC, (trivtuple (numind (pA)), ()), false ,
106+ ((), ()), One (), ba...
107+ )
108+ )
99109 return projectα (_dα)
100110 end
101111 dβ = @thunk let
102112 # TODO : consider using `inner`
103- _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pA))), true ,
104- ΔC, (trivtuple (numind (pA)), ()), false ,
105- ((), ()), One (), ba... ))
113+ _dβ = tensorscalar (
114+ tensorcontract (
115+ C, ((), trivtuple (numind (pA))), true ,
116+ ΔC, (trivtuple (numind (pA)), ()), false ,
117+ ((), ()), One (), ba...
118+ )
119+ )
106120 return projectβ (_dβ)
107121 end
108122 dba = map (_ -> NoTangent (), ba)
@@ -141,13 +155,15 @@ end
141155# α::Number, β::Number)
142156# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
143157# end
144- function ChainRulesCore. rrule (:: typeof (TensorOperations. tensorcontract!),
145- C,
146- A, pA:: Index2Tuple , conjA:: Bool ,
147- B, pB:: Index2Tuple , conjB:: Bool ,
148- pAB:: Index2Tuple ,
149- α:: Number , β:: Number ,
150- ba... )
158+ function ChainRulesCore. rrule (
159+ :: typeof (TensorOperations. tensorcontract!),
160+ C,
161+ A, pA:: Index2Tuple , conjA:: Bool ,
162+ B, pB:: Index2Tuple , conjB:: Bool ,
163+ pAB:: Index2Tuple ,
164+ α:: Number , β:: Number ,
165+ ba...
166+ )
151167 return _rrule_tensorcontract! (C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
152168end
153169function _rrule_tensorcontract! (C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
@@ -162,52 +178,66 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
162178 function pullback (ΔC′)
163179 ΔC = unthunk (ΔC′)
164180 ipAB = invperm (linearize (pAB))
165- pΔC = (TupleTools. getindices (ipAB, trivtuple (numout (pA))),
166- TupleTools. getindices (ipAB, numout (pA) .+ trivtuple (numin (pB))))
181+ pΔC = (
182+ TupleTools. getindices (ipAB, trivtuple (numout (pA))),
183+ TupleTools. getindices (ipAB, numout (pA) .+ trivtuple (numin (pB))),
184+ )
167185 dC = @thunk projectC (scale (ΔC, conj (β)))
168186 dA = @thunk let
169187 ipA = (invperm (linearize (pA)), ())
170188 conjΔC = conjA
171189 conjB′ = conjA ? conjB : ! conjB
172190 _dA = zerovector (A, promote_contract (scalartype (ΔC), scalartype (B), typeof (α)))
173- _dA = tensorcontract! (_dA,
174- ΔC, pΔC, conjΔC,
175- B, reverse (pB), conjB′,
176- ipA,
177- conjA ? α : conj (α), Zero (), ba... )
191+ _dA = tensorcontract! (
192+ _dA,
193+ ΔC, pΔC, conjΔC,
194+ B, reverse (pB), conjB′,
195+ ipA,
196+ conjA ? α : conj (α), Zero (), ba...
197+ )
178198 return projectA (_dA)
179199 end
180200 dB = @thunk let
181201 ipB = (invperm (linearize (pB)), ())
182202 conjΔC = conjB
183203 conjA′ = conjB ? conjA : ! conjA
184204 _dB = zerovector (B, promote_contract (scalartype (ΔC), scalartype (A), typeof (α)))
185- _dB = tensorcontract! (_dB,
186- A, reverse (pA), conjA′,
187- ΔC, pΔC, conjΔC,
188- ipB,
189- conjB ? α : conj (α), Zero (), ba... )
205+ _dB = tensorcontract! (
206+ _dB,
207+ A, reverse (pA), conjA′,
208+ ΔC, pΔC, conjΔC,
209+ ipB,
210+ conjB ? α : conj (α), Zero (), ba...
211+ )
190212 return projectB (_dB)
191213 end
192214 dα = @thunk let
193215 C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
194216 # TODO : consider using `inner`
195- _dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (pAB))), true ,
196- ΔC, (trivtuple (numind (pAB)), ()), false ,
197- ((), ()), One (), ba... ))
217+ _dα = tensorscalar (
218+ tensorcontract (
219+ C_αβ, ((), trivtuple (numind (pAB))), true ,
220+ ΔC, (trivtuple (numind (pAB)), ()), false ,
221+ ((), ()), One (), ba...
222+ )
223+ )
198224 return projectα (_dα)
199225 end
200226 dβ = @thunk let
201227 # TODO : consider using `inner`
202- _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pAB))), true ,
203- ΔC, (trivtuple (numind (pAB)), ()), false ,
204- ((), ()), One (), ba... ))
228+ _dβ = tensorscalar (
229+ tensorcontract (
230+ C, ((), trivtuple (numind (pAB))), true ,
231+ ΔC, (trivtuple (numind (pAB)), ()), false ,
232+ ((), ()), One (), ba...
233+ )
234+ )
205235 return projectβ (_dβ)
206236 end
207237 dba = map (_ -> NoTangent (), ba)
208238 return NoTangent (), dC,
209- dA, NoTangent (), NoTangent (), dB, NoTangent (), NoTangent (),
210- NoTangent (), dα, dβ, dba...
239+ dA, NoTangent (), NoTangent (), dB, NoTangent (), NoTangent (),
240+ NoTangent (), dα, dβ, dba...
211241 end
212242
213243 return C′, pullback
@@ -232,10 +262,12 @@ end
232262# α::Number, β::Number)
233263# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
234264# end
235- function ChainRulesCore. rrule (:: typeof (tensortrace!), C,
236- A, p:: Index2Tuple , q:: Index2Tuple , conjA:: Bool ,
237- α:: Number , β:: Number ,
238- ba... )
265+ function ChainRulesCore. rrule (
266+ :: typeof (tensortrace!), C,
267+ A, p:: Index2Tuple , q:: Index2Tuple , conjA:: Bool ,
268+ α:: Number , β:: Number ,
269+ ba...
270+ )
239271 return _rrule_tensortrace! (C, A, p, q, conjA, α, β, ba)
240272end
241273function _rrule_tensortrace! (C, A, p, q, conjA, α, β, ba)
@@ -252,29 +284,43 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
252284 dA = @thunk let
253285 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
254286 Es = map (q[1 ], q[2 ]) do i1, i2
255- return one (TensorOperations. tensoralloc_add (scalartype (A), A,
256- ((i1,), (i2,)), conjA))
287+ return one (
288+ TensorOperations. tensoralloc_add (
289+ scalartype (A), A,
290+ ((i1,), (i2,)), conjA
291+ )
292+ )
257293 end
258294 E = _kron (Es, ba)
259295 _dA = zerovector (A, VectorInterface. promote_scale (ΔC, α))
260- _dA = tensorproduct! (_dA, ΔC, (trivtuple (numind (p)), ()), conjA,
261- E, ((), trivtuple (numind (q))), conjA,
262- (ip, ()),
263- conjA ? α : conj (α), Zero (), ba... )
296+ _dA = tensorproduct! (
297+ _dA, ΔC, (trivtuple (numind (p)), ()), conjA,
298+ E, ((), trivtuple (numind (q))), conjA,
299+ (ip, ()),
300+ conjA ? α : conj (α), Zero (), ba...
301+ )
264302 return projectA (_dA)
265303 end
266304 dα = @thunk let
267305 C_αβ = tensortrace (A, p, q, false , One (), ba... )
268- _dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (p))),
269- ! conjA,
270- ΔC, (trivtuple (numind (p)), ()), false ,
271- ((), ()), One (), ba... ))
306+ _dα = tensorscalar (
307+ tensorcontract (
308+ C_αβ, ((), trivtuple (numind (p))),
309+ ! conjA,
310+ ΔC, (trivtuple (numind (p)), ()), false ,
311+ ((), ()), One (), ba...
312+ )
313+ )
272314 return projectα (_dα)
273315 end
274316 dβ = @thunk let
275- _dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (p))), true ,
276- ΔC, (trivtuple (numind (p)), ()), false ,
277- ((), ()), One (), ba... ))
317+ _dβ = tensorscalar (
318+ tensorcontract (
319+ C, ((), trivtuple (numind (p))), true ,
320+ ΔC, (trivtuple (numind (p)), ()), false ,
321+ ((), ()), One (), ba...
322+ )
323+ )
278324 return projectβ (_dβ)
279325 end
280326 dba = map (_ -> NoTangent (), ba)
@@ -285,7 +331,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
285331end
286332
287333_kron (Es:: NTuple{1} , ba) = Es[1 ]
288- function _kron (Es:: NTuple{N,Any} , ba) where {N}
334+ function _kron (Es:: NTuple{N, Any} , ba) where {N}
289335 E1 = Es[1 ]
290336 E2 = _kron (Base. tail (Es), ba)
291337 p2 = ((), trivtuple (2 * N - 2 ))
0 commit comments