Skip to content

Commit 0dedf92

Browse files
committed
code style
1 parent 2e06e54 commit 0dedf92

File tree

2 files changed

+185
-145
lines changed

2 files changed

+185
-145
lines changed

src/projection.jl

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function (project::ProjectTo{T})(dx::Tangent) where {T}
6767
end
6868

6969
# Used for encoding fields, leaves alone non-diff types:
70-
_maybe_projector(x::Union{AbstractArray, Number, Ref}) = ProjectTo(x)
70+
_maybe_projector(x::Union{AbstractArray,Number,Ref}) = ProjectTo(x)
7171
_maybe_projector(x) = x
7272
# Used for re-constructing fields, restores non-diff types:
7373
_maybe_call(f::ProjectTo, x) = f(x)
@@ -161,7 +161,7 @@ end
161161
function ProjectTo(xs::AbstractArray)
162162
elements = map(ProjectTo, xs)
163163
if elements isa AbstractArray{<:ProjectTo{<:AbstractZero}}
164-
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
164+
return ProjectTo{NoTangent}() # short-circuit if all elements project to zero
165165
else
166166
# Arrays of arrays come here, and will apply projectors individually:
167167
return ProjectTo{AbstractArray}(; elements=elements, axes=axes(xs))
@@ -175,7 +175,9 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
175175
dx
176176
else
177177
for d in 1:max(M, length(project.axes))
178-
size(dx, d) == length(get(project.axes, d, 1)) || throw(_projection_mismatch(project.axes, size(dx)))
178+
if size(dx, d) != length(get(project.axes, d, 1))
179+
throw(_projection_mismatch(project.axes, size(dx)))
180+
end
179181
end
180182
reshape(dx, project.axes)
181183
end
@@ -185,29 +187,37 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
185187
T = project_type(project.element)
186188
S <: T ? dy : map(project.element, dy)
187189
else
188-
map((f,y) -> f(y), project.elements, dy)
190+
map((f, y) -> f(y), project.elements, dy)
189191
end
190192
return dz
191193
end
192194

193195
# Row vectors aren't acceptable as gradients for 1-row matrices:
194-
(project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec) = project(reshape(vec(dx),1,:))
196+
function (project::ProjectTo{AbstractArray})(dx::LinearAlgebra.AdjOrTransAbsVec)
197+
return project(reshape(vec(dx), 1, :))
198+
end
195199

196200
# Zero-dimensional arrays -- these have a habit of going missing,
197201
# although really Ref() is probably a better structure.
198202
function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore from numbers
199-
project.axes isa Tuple{} || throw(DimensionMismatch("array with ndims(x) == $(length(project.axes)) > 0 cannot have as gradient dx::Number"))
203+
if !(project.axes isa Tuple{})
204+
throw(DimensionMismatch(
205+
"array with ndims(x) == $(length(project.axes)) > 0 cannot have dx::Number",
206+
))
207+
end
200208
return fill(project.element(dx))
201209
end
202210

203211
# Ref -- works like a zero-array, also allows restoration from a number:
204-
ProjectTo(x::Ref) = ProjectTo{Ref}(; x = ProjectTo(x[]))
212+
ProjectTo(x::Ref) = ProjectTo{Ref}(; x=ProjectTo(x[]))
205213
(project::ProjectTo{Ref})(dx::Ref) = Ref(project.x(dx[]))
206214
(project::ProjectTo{Ref})(dx::Number) = Ref(project.x(dx))
207215

208216
function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
209217
size_x = map(length, axes_x)
210-
return DimensionMismatch("variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx")
218+
return DimensionMismatch(
219+
"variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx"
220+
)
211221
end
212222

213223
#####
@@ -217,25 +227,33 @@ end
217227
# Row vectors
218228
function ProjectTo(x::LinearAlgebra.AdjointAbsVec)
219229
sub = ProjectTo(parent(x))
220-
ProjectTo{Adjoint}(; parent=sub)
230+
return ProjectTo{Adjoint}(; parent=sub)
221231
end
222232
# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
223233
# Transposed matrices are, like PermutedDimsArray, just a storage detail,
224234
# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
225-
(project::ProjectTo{Adjoint})(dx::LinearAlgebra.AdjOrTransAbsVec) = adjoint(project.parent(adjoint(dx)))
235+
function (project::ProjectTo{Adjoint})(dx::LinearAlgebra.AdjOrTransAbsVec)
236+
return adjoint(project.parent(adjoint(dx)))
237+
end
226238
function (project::ProjectTo{Adjoint})(dx::AbstractArray)
227-
size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
239+
if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1])
240+
throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
241+
end
228242
dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
229243
return adjoint(project.parent(dy))
230244
end
231245

232246
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
233247
sub = ProjectTo(parent(x))
234-
ProjectTo{Transpose}(; parent=sub)
248+
return ProjectTo{Transpose}(; parent=sub)
249+
end
250+
function (project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec)
251+
return transpose(project.parent(transpose(dx)))
235252
end
236-
(project::ProjectTo{Transpose})(dx::LinearAlgebra.AdjOrTransAbsVec) = transpose(project.parent(transpose(dx)))
237253
function (project::ProjectTo{Transpose})(dx::AbstractArray)
238-
size(dx,1) == 1 && size(dx,2) == length(project.parent.axes[1]) || throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
254+
if size(dx, 1) != 1 || size(dx, 2) != length(project.parent.axes[1])
255+
throw(_projection_mismatch((1:1, project.parent.axes...), size(dx)))
256+
end
239257
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
240258
return transpose(project.parent(dy))
241259
end
@@ -250,7 +268,10 @@ end
250268
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
251269

252270
# Symmetric
253-
for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint))
271+
for (SymHerm, chk, fun) in (
272+
(:Symmetric, :issymmetric, :transpose),
273+
(:Hermitian, :ishermitian, :adjoint),
274+
)
254275
@eval begin
255276
function ProjectTo(x::$SymHerm)
256277
sub = ProjectTo(parent(x))
@@ -268,7 +289,9 @@ for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian,
268289
# not clear how broadly it's worthwhile to try to support this.
269290
function (project::ProjectTo{$SymHerm})(dx::Diagonal)
270291
sub = project.parent # this is going to be unhappy about the size
271-
sub_one = ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],))
292+
sub_one = ProjectTo{project_type(sub)}(;
293+
element=sub.element, axes=(sub.axes[1],)
294+
)
272295
return Diagonal(sub_one(dx.diag))
273296
end
274297
end
@@ -279,13 +302,16 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT
279302
@eval begin
280303
function ProjectTo(x::$UL)
281304
sub = ProjectTo(parent(x))
282-
sub isa ProjectTo{<:AbstractZero} && return sub # TODO not necc if UnitUpperTriangular(NoTangent()) etc. worked
305+
# TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
306+
sub isa ProjectTo{<:AbstractZero} && return sub
283307
return ProjectTo{$UL}(; parent=sub)
284308
end
285309
(project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx))
286310
function (project::ProjectTo{$UL})(dx::Diagonal)
287311
sub = project.parent
288-
sub_one = ProjectTo{project_type(sub)}(; element = sub.element, axes = (sub.axes[1],))
312+
sub_one = ProjectTo{project_type(sub)}(;
313+
element=sub.element, axes=(sub.axes[1],)
314+
)
289315
return Diagonal(sub_one(dx.diag))
290316
end
291317
end
@@ -306,7 +332,7 @@ function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal)
306332
else
307333
uplo = LinearAlgebra.sym_uplo(project.uplo)
308334
dv = project.dv(diag(dx))
309-
ev = fill!(similar(dv, length(dv)-1), 0)
335+
ev = fill!(similar(dv, length(dv) - 1), 0)
310336
return Bidiagonal(dv, ev, uplo)
311337
end
312338
end
@@ -321,8 +347,8 @@ end
321347

322348
# another strategy is just to use the AbstractArray method
323349
function ProjectTo(x::Tridiagonal{T}) where {T<:Number}
324-
notparent = invoke(ProjectTo, Tuple{AbstractArray{T}} where T<:Number, x)
325-
return ProjectTo{Tridiagonal}(; notparent = notparent)
350+
notparent = invoke(ProjectTo, Tuple{AbstractArray{T}} where {T<:Number}, x)
351+
return ProjectTo{Tridiagonal}(; notparent=notparent)
326352
end
327353
function (project::ProjectTo{Tridiagonal})(dx::AbstractArray)
328354
dy = project.notparent(dx)
@@ -340,20 +366,26 @@ using SparseArrays
340366
# This implementation very naiive, can probably be made more efficient.
341367

342368
function ProjectTo(x::SparseVector{T}) where {T<:Number}
343-
return ProjectTo{SparseVector}(; element = ProjectTo(zero(T)), nzind = x.nzind, axes = axes(x))
369+
return ProjectTo{SparseVector}(;
370+
element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x)
371+
)
344372
end
345373
function (project::ProjectTo{SparseVector})(dx::AbstractArray)
346374
dy = if axes(dx) == project.axes
347375
dx
348376
else
349-
size(dx, 1) == length(project.axes[1]) || throw(_projection_mismatch(project.axes, size(dx)))
377+
if size(dx, 1) != length(project.axes[1])
378+
throw(_projection_mismatch(project.axes, size(dx)))
379+
end
350380
reshape(dx, project.axes)
351381
end
352382
nzval = map(i -> project.element(dy[i]), project.nzind)
353383
return SparseVector(length(dx), project.nzind, nzval)
354384
end
355385
function (project::ProjectTo{SparseVector})(dx::SparseVector)
356-
size(dx) == map(length, project.axes) || throw(_projection_mismatch(project.axes, size(dx)))
386+
if size(dx) != map(length, project.axes)
387+
throw(_projection_mismatch(project.axes, size(dx)))
388+
end
357389
# When sparsity pattern is unchanged, all the time is in checking this,
358390
# perhaps some simple hash/checksum might be good enough?
359391
samepattern = project.nzind == dx.nzind
@@ -373,17 +405,23 @@ function (project::ProjectTo{SparseVector})(dx::SparseVector)
373405
end
374406

375407
function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
376-
ProjectTo{SparseMatrixCSC}(; element = ProjectTo(zero(T)), axes = axes(x),
377-
rowval = rowvals(x), nzranges = nzrange.(Ref(x), axes(x,2)), colptr = x.colptr)
408+
return ProjectTo{SparseMatrixCSC}(;
409+
element=ProjectTo(zero(T)),
410+
axes=axes(x),
411+
rowval=rowvals(x),
412+
nzranges=nzrange.(Ref(x), axes(x, 2)),
413+
colptr=x.colptr,
414+
)
378415
end
379416
# You need not really store nzranges, you can get them from colptr -- TODO
380417
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
381418
function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
382419
dy = if axes(dx) == project.axes
383420
dx
384421
else
385-
size(dx, 1) == length(project.axes[1]) || throw(_projection_mismatch(project.axes, size(dx)))
386-
size(dx, 2) == length(project.axes[2]) || throw(_projection_mismatch(project.axes, size(dx)))
422+
if size(dx) != (length(project.axes[1]), length(project.axes[2]))
423+
throw(_projection_mismatch(project.axes, size(dx)))
424+
end
387425
reshape(dx, project.axes)
388426
end
389427
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
@@ -392,15 +430,17 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
392430
for i in project.nzranges[col]
393431
row = project.rowval[i]
394432
val = dy[row, col]
395-
nzval[k+=1] = project.element(val)
433+
nzval[k += 1] = project.element(val)
396434
end
397435
end
398436
m, n = map(length, project.axes)
399437
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
400438
end
401439

402440
function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
403-
size(dx) == map(length, project.axes) || throw(_projection_mismatch(project.axes, size(dx)))
441+
if size(dx) != map(length, project.axes)
442+
throw(_projection_mismatch(project.axes, size(dx)))
443+
end
404444
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
405445
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
406446
if eltype(dx) <: project_type(project.element) && samepattern

0 commit comments

Comments
 (0)