Skip to content

Commit db51820

Browse files
mcabbottoxinabox
authored andcommitted
sparse improvements
1 parent e125797 commit db51820

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

src/projection.jl

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ function generic_projector(x::T; kw...) where {T}
4747
# `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
4848
# but if it doesn't `construct` will give a good error message.
4949
wrapT = T.name.wrapper
50+
# Official API for this? https://github.com/JuliaLang/julia/issues/35543
5051
return ProjectTo{wrapT}(; fields_proj..., kw...)
5152
end
5253

@@ -352,13 +353,31 @@ function (project::ProjectTo{SparseVector})(dx::AbstractArray)
352353
reshape(dx, project.axes)
353354
end
354355
nzval = map(i -> project.element(dy[i]), project.nzind)
355-
n = length(project.axes[1])
356-
return SparseVector(n, project.nzind, nzval)
356+
return SparseVector(length(dx), project.nzind, nzval)
357+
end
358+
function (project::ProjectTo{SparseVector})(dx::SparseVector)
359+
size(dx) == map(length, project.axes) || throw(_projection_mismatch(project.axes, size(dx)))
360+
# When sparsity pattern is unchanged, all the time is in checking this,
361+
# perhaps some simple hash/checksum might be good enough?
362+
samepattern = project.nzind == dx.nzind
363+
# samepattern = length(project.nzind) == length(dx.nzind)
364+
if eltype(dx.nzval) <: project_type(project.element) && samepattern
365+
return dx
366+
elseif samepattern
367+
nzval = map(project.element, dx.nzval)
368+
SparseVector(length(dx), dx.nzind, nzval)
369+
else
370+
nzind = project.nzind
371+
# Or should we intersect? Can this exploit sorting?
372+
# nzind = intersect(project.nzind, dx.nzind)
373+
nzval = map(i -> project.element(dx[i]), nzind)
374+
return SparseVector(length(dx), nzind, nzval)
375+
end
357376
end
358377

359378
function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
360379
ProjectTo{SparseMatrixCSC}(; element = ProjectTo(zero(T)), axes = axes(x),
361-
rowvals = rowvals(x), nzranges = nzrange.(Ref(x), axes(x,2)), colptr = x.colptr)
380+
rowval = rowvals(x), nzranges = nzrange.(Ref(x), axes(x,2)), colptr = x.colptr)
362381
end
363382
# You need not really store nzranges, you can get them from colptr -- TODO
364383
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
@@ -370,15 +389,31 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
370389
size(dx, 2) == length(project.axes[2]) || throw(_projection_mismatch(project.axes, size(dx)))
371390
reshape(dx, project.axes)
372391
end
373-
nzval = Vector{project_type(project.element)}(undef, length(project.rowvals))
392+
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
374393
k = 0
375394
for col in project.axes[2]
376395
for i in project.nzranges[col]
377-
row = project.rowvals[i]
396+
row = project.rowval[i]
378397
val = dy[row, col]
379398
nzval[k+=1] = project.element(val)
380399
end
381400
end
382401
m, n = length.(project.axes)
383-
return SparseMatrixCSC(m, n, project.colptr, project.rowvals, nzval)
402+
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
403+
end
404+
405+
function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
406+
size(dx) == map(length, project.axes) || throw(_projection_mismatch(project.axes, size(dx)))
407+
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
408+
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
409+
if eltype(dx.nzval) <: project_type(project.element) && samepattern
410+
return dx
411+
elseif samepattern
412+
nzval = map(project.element, dx.nzval)
413+
m, n = size(dx)
414+
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
415+
else
416+
417+
invoke(project, Tuple{AbstractArray}, dx)
418+
end
384419
end

test/projection.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,14 @@ using OffsetArrays, BenchmarkTools
159159
pv = ProjectTo(v)
160160

161161
@test pv(v) == v
162-
@test pv(v .* (1+im)) v
163-
o = pv(ones(Int, 30, 1))
162+
@test pv(v .* (1+im)) v # same nonzero elements
163+
164+
o = pv(ones(Int, 30, 1)) # dense array
164165
@test nnz(o) == nnz(v)
165166

167+
v2 = sprand(30, 0.7) # different nonzero elements
168+
@test pv(v2) == pv(collect(v2))
169+
166170
# matrix
167171
m = sprand(10, 10, 0.3)
168172
pm = ProjectTo(m)
@@ -172,6 +176,9 @@ using OffsetArrays, BenchmarkTools
172176
om = pm(ones(Int, 10, 10))
173177
@test nnz(om) == nnz(m)
174178

179+
m2 = sprand(10, 10, 0.5)
180+
@test pm(m2) == pm(collect(m2))
181+
175182
@test_throws DimensionMismatch pv(ones(Int, 1, 30))
176183
@test_throws DimensionMismatch pm(ones(Int, 5, 20))
177184
end

0 commit comments

Comments
 (0)