@@ -336,14 +336,55 @@ function SparseArrays.sparsevec(I::CuArray{Ti}, V::CuArray{Tv}, n::Integer) wher
336
336
CuSparseVector (I, V, n)
337
337
end
338
338
339
- function SparseArrays. spdiagm (v:: CuVector{Tv} ) where {Tv}
340
- nzVal = v
341
- N = Int32 (length (nzVal))
342
-
343
- colPtr = CuArray (one (Int32): (N + one (Int32)))
344
- rowVal = CuArray (one (Int32): N)
345
- dims = (N, N)
346
- CuSparseMatrixCSC (colPtr, rowVal, nzVal, dims)
339
+ SparseArrays. spdiagm (kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_spdiagm (nothing , kv... )
340
+ SparseArrays. spdiagm (m:: Integer , n:: Integer , kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_spdiagm ((Int (m),Int (n)), kv... )
341
+ SparseArrays. spdiagm (v:: CuVector ) = _cuda_spdiagm (nothing , 0 => v)
342
+ SparseArrays. spdiagm (m:: Integer , n:: Integer , v:: CuVector ) = _cuda_spdiagm ((Int (m), Int (n)), 0 => v)
343
+
344
+ function _cuda_spdiagm (size, kv:: Pair{<:Integer, <:CuVector} ...)
345
+ I, J, V, mmax, nmax = _cuda_spdiagm_internal (kv... )
346
+ mnmax = max (mmax, nmax)
347
+ m, n = something (size, (mnmax,mnmax))
348
+ (m ≥ mmax && n ≥ nmax) || throw (DimensionMismatch (" invalid size=$size " ))
349
+ return sparse (CuVector (I), CuVector (J), V, m, n)
350
+ end
351
+
352
+ function _cuda_spdiagm_internal (kv:: Pair{T,<:CuVector} ...) where {T<: Integer }
353
+ ncoeffs = 0
354
+ for p in kv
355
+ ncoeffs += SparseArrays. _nnz (p. second)
356
+ end
357
+ I = Vector {T} (undef, ncoeffs)
358
+ J = Vector {T} (undef, ncoeffs)
359
+ V = CuArray {promote_type(map(x -> eltype(x.second), kv)...)} (undef, ncoeffs)
360
+ i = 0
361
+ m = 0
362
+ n = 0
363
+ for p in kv
364
+ k = p. first
365
+ v = p. second
366
+ if k < 0
367
+ row = - k
368
+ col = 0
369
+ elseif k > 0
370
+ row = 0
371
+ col = k
372
+ else
373
+ row = 0
374
+ col = 0
375
+ end
376
+ numel = SparseArrays. _nnz (v)
377
+ r = 1 + i: numel+ i
378
+ I_r, J_r = SparseArrays. _indices (v, row, col)
379
+ copyto! (view (I, r), I_r)
380
+ copyto! (view (J, r), J_r)
381
+ copyto! (view (V, r), v)
382
+ veclen = length (v)
383
+ m = max (m, row + veclen)
384
+ n = max (n, col + veclen)
385
+ i += numel
386
+ end
387
+ return I, J, V, m, n
347
388
end
348
389
349
390
LinearAlgebra. issymmetric (M:: Union{CuSparseMatrixCSC,CuSparseMatrixCSR} ) = size (M, 1 ) == size (M, 2 ) ? norm (M - transpose (M), Inf ) == 0 : false
0 commit comments