Skip to content

Commit 22a875b

Browse files
authored
Fix spdiagm with specified pairs (#2784)
* Fix spdiagm with specified pairs * Should subtype keys * Remove DS_store
1 parent a4a7af4 commit 22a875b

File tree

2 files changed

+66
-9
lines changed

2 files changed

+66
-9
lines changed

lib/cusparse/array.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,55 @@ function SparseArrays.sparsevec(I::CuArray{Ti}, V::CuArray{Tv}, n::Integer) wher
336336
CuSparseVector(I, V, n)
337337
end
338338

339-
function SparseArrays.spdiagm(v::CuVector{Tv}) where {Tv}
340-
nzVal = v
341-
N = Int32(length(nzVal))
342-
343-
colPtr = CuArray(one(Int32):(N + one(Int32)))
344-
rowVal = CuArray(one(Int32):N)
345-
dims = (N, N)
346-
CuSparseMatrixCSC(colPtr, rowVal, nzVal, dims)
339+
SparseArrays.spdiagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_spdiagm(nothing, kv...)
340+
SparseArrays.spdiagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:CuVector}...) = _cuda_spdiagm((Int(m),Int(n)), kv...)
341+
SparseArrays.spdiagm(v::CuVector) = _cuda_spdiagm(nothing, 0 => v)
342+
SparseArrays.spdiagm(m::Integer, n::Integer, v::CuVector) = _cuda_spdiagm((Int(m), Int(n)), 0 => v)
343+
344+
function _cuda_spdiagm(size, kv::Pair{<:Integer, <:CuVector}...)
345+
I, J, V, mmax, nmax = _cuda_spdiagm_internal(kv...)
346+
mnmax = max(mmax, nmax)
347+
m, n = something(size, (mnmax,mnmax))
348+
(m mmax && n nmax) || throw(DimensionMismatch("invalid size=$size"))
349+
return sparse(CuVector(I), CuVector(J), V, m, n)
350+
end
351+
352+
function _cuda_spdiagm_internal(kv::Pair{T,<:CuVector}...) where {T<:Integer}
353+
ncoeffs = 0
354+
for p in kv
355+
ncoeffs += SparseArrays._nnz(p.second)
356+
end
357+
I = Vector{T}(undef, ncoeffs)
358+
J = Vector{T}(undef, ncoeffs)
359+
V = CuArray{promote_type(map(x -> eltype(x.second), kv)...)}(undef, ncoeffs)
360+
i = 0
361+
m = 0
362+
n = 0
363+
for p in kv
364+
k = p.first
365+
v = p.second
366+
if k < 0
367+
row = -k
368+
col = 0
369+
elseif k > 0
370+
row = 0
371+
col = k
372+
else
373+
row = 0
374+
col = 0
375+
end
376+
numel = SparseArrays._nnz(v)
377+
r = 1+i:numel+i
378+
I_r, J_r = SparseArrays._indices(v, row, col)
379+
copyto!(view(I, r), I_r)
380+
copyto!(view(J, r), J_r)
381+
copyto!(view(V, r), v)
382+
veclen = length(v)
383+
m = max(m, row + veclen)
384+
n = max(n, col + veclen)
385+
i += numel
386+
end
387+
return I, J, V, m, n
347388
end
348389

349390
LinearAlgebra.issymmetric(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false

test/libraries/cusparse/interfaces.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,23 @@ end
581581
cuda_vec = CuVector(ref_vec)
582582

583583
ref_spdiagm = spdiagm(ref_vec) # SparseArrays
584-
cuda_spdiagm = spdiagm(cuda_vec) # CuSparseMatrixCSC
584+
cuda_spdiagm = spdiagm(cuda_vec)
585+
586+
ref_cuda_sparse = CuSparseMatrixCSC(ref_spdiagm)
587+
588+
@test ref_cuda_sparse.rowVal == cuda_spdiagm.rowVal
589+
590+
@test ref_cuda_sparse.nzVal == cuda_spdiagm.nzVal
591+
592+
@test ref_cuda_sparse.colPtr == cuda_spdiagm.colPtr
593+
end
594+
595+
@testset "spdiagm(2 => CuVector{$elty})" for elty in [Float32, Float64, ComplexF32, ComplexF64]
596+
ref_vec = collect(elty, 100:121)
597+
cuda_vec = CuVector(ref_vec)
598+
599+
ref_spdiagm = spdiagm(2 => ref_vec) # SparseArrays
600+
cuda_spdiagm = spdiagm(2 => cuda_vec)
585601

586602
ref_cuda_sparse = CuSparseMatrixCSC(ref_spdiagm)
587603

0 commit comments

Comments
 (0)