@@ -47,6 +47,7 @@ function generic_projector(x::T; kw...) where {T}
47
47
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
48
48
# but if it doesn't `construct` will give a good error message.
49
49
wrapT = T. name. wrapper
50
+ # Official API for this? https://github.com/JuliaLang/julia/issues/35543
50
51
return ProjectTo {wrapT} (; fields_proj... , kw... )
51
52
end
52
53
@@ -352,13 +353,31 @@ function (project::ProjectTo{SparseVector})(dx::AbstractArray)
352
353
reshape (dx, project. axes)
353
354
end
354
355
nzval = map (i -> project. element (dy[i]), project. nzind)
355
- n = length (project. axes[1 ])
356
- return SparseVector (n, project. nzind, nzval)
356
+ return SparseVector (length (dx), project. nzind, nzval)
357
+ end
358
+ function (project:: ProjectTo{SparseVector} )(dx:: SparseVector )
359
+ size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
360
+ # When sparsity pattern is unchanged, all the time is in checking this,
361
+ # perhaps some simple hash/checksum might be good enough?
362
+ samepattern = project. nzind == dx. nzind
363
+ # samepattern = length(project.nzind) == length(dx.nzind)
364
+ if eltype (dx. nzval) <: project_type (project. element) && samepattern
365
+ return dx
366
+ elseif samepattern
367
+ nzval = map (project. element, dx. nzval)
368
+ SparseVector (length (dx), dx. nzind, nzval)
369
+ else
370
+ nzind = project. nzind
371
+ # Or should we intersect? Can this exploit sorting?
372
+ # nzind = intersect(project.nzind, dx.nzind)
373
+ nzval = map (i -> project. element (dx[i]), nzind)
374
+ return SparseVector (length (dx), nzind, nzval)
375
+ end
357
376
end
358
377
359
378
function ProjectTo (x:: SparseMatrixCSC{T} ) where {T<: Number }
360
379
ProjectTo {SparseMatrixCSC} (; element = ProjectTo (zero (T)), axes = axes (x),
361
- rowvals = rowvals (x), nzranges = nzrange .(Ref (x), axes (x,2 )), colptr = x. colptr)
380
+ rowval = rowvals (x), nzranges = nzrange .(Ref (x), axes (x,2 )), colptr = x. colptr)
362
381
end
363
382
# You need not really store nzranges, you can get them from colptr -- TODO
364
383
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
@@ -370,15 +389,31 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
370
389
size (dx, 2 ) == length (project. axes[2 ]) || throw (_projection_mismatch (project. axes, size (dx)))
371
390
reshape (dx, project. axes)
372
391
end
373
- nzval = Vector {project_type(project.element)} (undef, length (project. rowvals ))
392
+ nzval = Vector {project_type(project.element)} (undef, length (project. rowval ))
374
393
k = 0
375
394
for col in project. axes[2 ]
376
395
for i in project. nzranges[col]
377
- row = project. rowvals [i]
396
+ row = project. rowval [i]
378
397
val = dy[row, col]
379
398
nzval[k+= 1 ] = project. element (val)
380
399
end
381
400
end
382
401
m, n = length .(project. axes)
383
- return SparseMatrixCSC (m, n, project. colptr, project. rowvals, nzval)
402
+ return SparseMatrixCSC (m, n, project. colptr, project. rowval, nzval)
403
+ end
404
+
405
+ function (project:: ProjectTo{SparseMatrixCSC} )(dx:: SparseMatrixCSC )
406
+ size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
407
+ samepattern = dx. colptr == project. colptr && dx. rowval == project. rowval
408
+ # samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
409
+ if eltype (dx. nzval) <: project_type (project. element) && samepattern
410
+ return dx
411
+ elseif samepattern
412
+ nzval = map (project. element, dx. nzval)
413
+ m, n = size (dx)
414
+ return SparseMatrixCSC (m, n, dx. colptr, dx. rowval, nzval)
415
+ else
416
+
417
+ invoke (project, Tuple{AbstractArray}, dx)
418
+ end
384
419
end
0 commit comments