@@ -67,7 +67,7 @@ function (project::ProjectTo{T})(dx::Tangent) where {T}
67
67
end
68
68
69
69
# Used for encoding fields, leaves alone non-diff types:
70
- _maybe_projector (x:: Union{AbstractArray, Number, Ref} ) = ProjectTo (x)
70
+ _maybe_projector (x:: Union{AbstractArray,Number,Ref} ) = ProjectTo (x)
71
71
_maybe_projector (x) = x
72
72
# Used for re-constructing fields, restores non-diff types:
73
73
_maybe_call (f:: ProjectTo , x) = f (x)
161
161
function ProjectTo (xs:: AbstractArray )
162
162
elements = map (ProjectTo, xs)
163
163
if elements isa AbstractArray{<: ProjectTo{<:AbstractZero} }
164
- return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
164
+ return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
165
165
else
166
166
# Arrays of arrays come here, and will apply projectors individually:
167
167
return ProjectTo {AbstractArray} (; elements= elements, axes= axes (xs))
@@ -175,7 +175,9 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
175
175
dx
176
176
else
177
177
for d in 1 : max (M, length (project. axes))
178
- size (dx, d) == length (get (project. axes, d, 1 )) || throw (_projection_mismatch (project. axes, size (dx)))
178
+ if size (dx, d) != length (get (project. axes, d, 1 ))
179
+ throw (_projection_mismatch (project. axes, size (dx)))
180
+ end
179
181
end
180
182
reshape (dx, project. axes)
181
183
end
@@ -185,29 +187,37 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
185
187
T = project_type (project. element)
186
188
S <: T ? dy : map (project. element, dy)
187
189
else
188
- map ((f,y) -> f (y), project. elements, dy)
190
+ map ((f, y) -> f (y), project. elements, dy)
189
191
end
190
192
return dz
191
193
end
192
194
193
195
# Row vectors aren't acceptable as gradients for 1-row matrices:
194
- (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = project (reshape (vec (dx),1 ,:))
196
+ function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
197
+ return project (reshape (vec (dx), 1 , :))
198
+ end
195
199
196
200
# Zero-dimensional arrays -- these have a habit of going missing,
197
201
# although really Ref() is probably a better structure.
198
202
function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
199
- project. axes isa Tuple{} || throw (DimensionMismatch (" array with ndims(x) == $(length (project. axes)) > 0 cannot have as gradient dx::Number" ))
203
+ if ! (project. axes isa Tuple{})
204
+ throw (DimensionMismatch (
205
+ " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
206
+ ))
207
+ end
200
208
return fill (project. element (dx))
201
209
end
202
210
203
211
# Ref -- works like a zero-array, also allows restoration from a number:
204
- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x = ProjectTo (x[]))
212
+ ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
205
213
(project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
206
214
(project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
207
215
208
216
function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
209
217
size_x = map (length, axes_x)
210
- return DimensionMismatch (" variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx " )
218
+ return DimensionMismatch (
219
+ " variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx "
220
+ )
211
221
end
212
222
213
223
# ####
@@ -217,25 +227,33 @@ end
217
227
# Row vectors
218
228
function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
219
229
sub = ProjectTo (parent (x))
220
- ProjectTo {Adjoint} (; parent= sub)
230
+ return ProjectTo {Adjoint} (; parent= sub)
221
231
end
222
232
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
223
233
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
224
234
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
225
- (project:: ProjectTo{Adjoint} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = adjoint (project. parent (adjoint (dx)))
235
+ function (project:: ProjectTo{Adjoint} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
236
+ return adjoint (project. parent (adjoint (dx)))
237
+ end
226
238
function (project:: ProjectTo{Adjoint} )(dx:: AbstractArray )
227
- size (dx,1 ) == 1 && size (dx,2 ) == length (project. parent. axes[1 ]) || throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
239
+ if size (dx, 1 ) != 1 || size (dx, 2 ) != length (project. parent. axes[1 ])
240
+ throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
241
+ end
228
242
dy = eltype (dx) <: Real ? vec (dx) : adjoint (dx)
229
243
return adjoint (project. parent (dy))
230
244
end
231
245
232
246
function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
233
247
sub = ProjectTo (parent (x))
234
- ProjectTo {Transpose} (; parent= sub)
248
+ return ProjectTo {Transpose} (; parent= sub)
249
+ end
250
+ function (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
251
+ return transpose (project. parent (transpose (dx)))
235
252
end
236
- (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = transpose (project. parent (transpose (dx)))
237
253
function (project:: ProjectTo{Transpose} )(dx:: AbstractArray )
238
- size (dx,1 ) == 1 && size (dx,2 ) == length (project. parent. axes[1 ]) || throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
254
+ if size (dx, 1 ) != 1 || size (dx, 2 ) != length (project. parent. axes[1 ])
255
+ throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
256
+ end
239
257
dy = eltype (dx) <: Number ? vec (dx) : transpose (dx)
240
258
return transpose (project. parent (dy))
241
259
end
250
268
(project:: ProjectTo{Diagonal} )(dx:: Diagonal ) = Diagonal (project. diag (dx. diag))
251
269
252
270
# Symmetric
253
- for (SymHerm, chk, fun) in ((:Symmetric , :issymmetric , :transpose ), (:Hermitian , :ishermitian , :adjoint ))
271
+ for (SymHerm, chk, fun) in (
272
+ (:Symmetric , :issymmetric , :transpose ),
273
+ (:Hermitian , :ishermitian , :adjoint ),
274
+ )
254
275
@eval begin
255
276
function ProjectTo (x:: $SymHerm )
256
277
sub = ProjectTo (parent (x))
@@ -268,7 +289,9 @@ for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian,
268
289
# not clear how broadly it's worthwhile to try to support this.
269
290
function (project:: ProjectTo{$SymHerm} )(dx:: Diagonal )
270
291
sub = project. parent # this is going to be unhappy about the size
271
- sub_one = ProjectTo {project_type(sub)} (; element = sub. element, axes = (sub. axes[1 ],))
292
+ sub_one = ProjectTo {project_type(sub)} (;
293
+ element= sub. element, axes= (sub. axes[1 ],)
294
+ )
272
295
return Diagonal (sub_one (dx. diag))
273
296
end
274
297
end
@@ -279,13 +302,16 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT
279
302
@eval begin
280
303
function ProjectTo (x:: $UL )
281
304
sub = ProjectTo (parent (x))
282
- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if UnitUpperTriangular(NoTangent()) etc. worked
305
+ # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
306
+ sub isa ProjectTo{<: AbstractZero } && return sub
283
307
return ProjectTo {$UL} (; parent= sub)
284
308
end
285
309
(project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
286
310
function (project:: ProjectTo{$UL} )(dx:: Diagonal )
287
311
sub = project. parent
288
- sub_one = ProjectTo {project_type(sub)} (; element = sub. element, axes = (sub. axes[1 ],))
312
+ sub_one = ProjectTo {project_type(sub)} (;
313
+ element= sub. element, axes= (sub. axes[1 ],)
314
+ )
289
315
return Diagonal (sub_one (dx. diag))
290
316
end
291
317
end
@@ -306,7 +332,7 @@ function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal)
306
332
else
307
333
uplo = LinearAlgebra. sym_uplo (project. uplo)
308
334
dv = project. dv (diag (dx))
309
- ev = fill! (similar (dv, length (dv)- 1 ), 0 )
335
+ ev = fill! (similar (dv, length (dv) - 1 ), 0 )
310
336
return Bidiagonal (dv, ev, uplo)
311
337
end
312
338
end
321
347
322
348
# another strategy is just to use the AbstractArray method
323
349
function ProjectTo (x:: Tridiagonal{T} ) where {T<: Number }
324
- notparent = invoke (ProjectTo, Tuple{AbstractArray{T}} where T<: Number , x)
325
- return ProjectTo {Tridiagonal} (; notparent = notparent)
350
+ notparent = invoke (ProjectTo, Tuple{AbstractArray{T}} where { T<: Number } , x)
351
+ return ProjectTo {Tridiagonal} (; notparent= notparent)
326
352
end
327
353
function (project:: ProjectTo{Tridiagonal} )(dx:: AbstractArray )
328
354
dy = project. notparent (dx)
@@ -340,20 +366,26 @@ using SparseArrays
340
366
# This implementation very naiive, can probably be made more efficient.
341
367
342
368
function ProjectTo (x:: SparseVector{T} ) where {T<: Number }
343
- return ProjectTo {SparseVector} (; element = ProjectTo (zero (T)), nzind = x. nzind, axes = axes (x))
369
+ return ProjectTo {SparseVector} (;
370
+ element= ProjectTo (zero (T)), nzind= x. nzind, axes= axes (x)
371
+ )
344
372
end
345
373
function (project:: ProjectTo{SparseVector} )(dx:: AbstractArray )
346
374
dy = if axes (dx) == project. axes
347
375
dx
348
376
else
349
- size (dx, 1 ) == length (project. axes[1 ]) || throw (_projection_mismatch (project. axes, size (dx)))
377
+ if size (dx, 1 ) != length (project. axes[1 ])
378
+ throw (_projection_mismatch (project. axes, size (dx)))
379
+ end
350
380
reshape (dx, project. axes)
351
381
end
352
382
nzval = map (i -> project. element (dy[i]), project. nzind)
353
383
return SparseVector (length (dx), project. nzind, nzval)
354
384
end
355
385
function (project:: ProjectTo{SparseVector} )(dx:: SparseVector )
356
- size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
386
+ if size (dx) != map (length, project. axes)
387
+ throw (_projection_mismatch (project. axes, size (dx)))
388
+ end
357
389
# When sparsity pattern is unchanged, all the time is in checking this,
358
390
# perhaps some simple hash/checksum might be good enough?
359
391
samepattern = project. nzind == dx. nzind
@@ -373,17 +405,23 @@ function (project::ProjectTo{SparseVector})(dx::SparseVector)
373
405
end
374
406
375
407
function ProjectTo (x:: SparseMatrixCSC{T} ) where {T<: Number }
376
- ProjectTo {SparseMatrixCSC} (; element = ProjectTo (zero (T)), axes = axes (x),
377
- rowval = rowvals (x), nzranges = nzrange .(Ref (x), axes (x,2 )), colptr = x. colptr)
408
+ return ProjectTo {SparseMatrixCSC} (;
409
+ element= ProjectTo (zero (T)),
410
+ axes= axes (x),
411
+ rowval= rowvals (x),
412
+ nzranges= nzrange .(Ref (x), axes (x, 2 )),
413
+ colptr= x. colptr,
414
+ )
378
415
end
379
416
# You need not really store nzranges, you can get them from colptr -- TODO
380
417
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
381
418
function (project:: ProjectTo{SparseMatrixCSC} )(dx:: AbstractArray )
382
419
dy = if axes (dx) == project. axes
383
420
dx
384
421
else
385
- size (dx, 1 ) == length (project. axes[1 ]) || throw (_projection_mismatch (project. axes, size (dx)))
386
- size (dx, 2 ) == length (project. axes[2 ]) || throw (_projection_mismatch (project. axes, size (dx)))
422
+ if size (dx) != (length (project. axes[1 ]), length (project. axes[2 ]))
423
+ throw (_projection_mismatch (project. axes, size (dx)))
424
+ end
387
425
reshape (dx, project. axes)
388
426
end
389
427
nzval = Vector {project_type(project.element)} (undef, length (project. rowval))
@@ -392,15 +430,17 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
392
430
for i in project. nzranges[col]
393
431
row = project. rowval[i]
394
432
val = dy[row, col]
395
- nzval[k+= 1 ] = project. element (val)
433
+ nzval[k += 1 ] = project. element (val)
396
434
end
397
435
end
398
436
m, n = map (length, project. axes)
399
437
return SparseMatrixCSC (m, n, project. colptr, project. rowval, nzval)
400
438
end
401
439
402
440
function (project:: ProjectTo{SparseMatrixCSC} )(dx:: SparseMatrixCSC )
403
- size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
441
+ if size (dx) != map (length, project. axes)
442
+ throw (_projection_mismatch (project. axes, size (dx)))
443
+ end
404
444
samepattern = dx. colptr == project. colptr && dx. rowval == project. rowval
405
445
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
406
446
if eltype (dx) <: project_type (project. element) && samepattern
0 commit comments