@@ -84,7 +84,8 @@ end
 end
 end
 end
-@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
+@inline function _capturescalars(arg)
+    # this definition is just an optimization (to bottom out the recursion slightly sooner)
     if scalararg(arg)
         return (), () -> (arg,) # add scalararg
     elseif scalarwrappedarg(arg)
@@ -103,7 +104,6 @@
 
 ## COV_EXCL_START
 ## iteration helpers
-
 """
     CSRIterator{Ti}(row, args...)
 
@@ -288,15 +288,20 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, I, ptr)
+@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
+                                      CuSparseDeviceMatrixCSC{Tv},
+                                      CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
     if ptr == 0
-        zero(eltype(arg))
+        return zero(Tv)
     else
-        @inbounds arg.nzVal[ptr]
+        return @inbounds arg.nzVal[ptr]::Tv
     end
 end
-_getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
+@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv}
+    return @inbounds arg[I]::Tv
+end
+@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
 ## sparse broadcast implementation
 
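The `_getindex` helpers above give every broadcast argument a uniform `(arg, I, ptr)` interface: a `ptr` of zero means "no stored entry here", anything else indexes straight into the stored values. A rough CPU-side analogue, assuming SparseArrays and a hypothetical `getindex_analogue` (illustration only, not part of this PR):

    using SparseArrays

    # ptr == 0 encodes "no stored entry at this row", so return zero(Tv);
    # otherwise ptr indexes directly into the stored values.
    getindex_analogue(v::SparseVector{Tv}, I, ptr) where {Tv} =
        ptr == 0 ? zero(Tv) : v.nzval[ptr]

    v = sparsevec([2, 4], [10.0, 20.0], 6)
    getindex_analogue(v, 2, 1)  # 10.0 (row 2 is the first stored entry)
    getindex_analogue(v, 3, 0)  # 0.0  (row 3 has no stored entry)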
@@ -305,8 +310,46 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
+_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row
+_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row
+function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32
+    for row_ix in 1i32:length(A.iPtr)
+        arg_row = @inbounds A.iPtr[row_ix]
+        arg_row == row && return row_ix
+        arg_row > row && break
+    end
+    return 0i32
+end
+
+function _get_my_row(first_row)::Int32
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    return row_ix + first_row - 1i32
+end
+
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
+                                fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                args...) where {Ti, N}
+    row = _get_my_row(first_row)
+    row > last_row && return
+
+    # TODO: load arg.iPtr slices into shared memory
+    row_is_nnz = 0i32
+    arg_row_is_nnz = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        _has_row(arg, offsets, row, fpreszeros)::Int32
+    end
+    row_is_nnz = 0i32
+    for i in 1:N
+        row_is_nnz |= @inbounds arg_row_is_nnz[i]
+    end
+    key = (row_is_nnz == 0i32) ? typemax(Ti) : row
+    @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz
+    return
+end
+
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
+function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
+                                offsets::AbstractVector{Ti},
                                 args...) where Ti
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
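The vector kernel above records, per candidate row, a `row => (per-argument pointers)` pair, using `typemax(Ti)` as the key for rows where no argument has a stored entry so they sort to the back later. A CPU sketch of that bookkeeping for two sparse-vector arguments (assumed semantics; `has_row` stands in for the device-side `_has_row`):

    using SparseArrays

    has_row(v::SparseVector, row) = something(findfirst(==(row), v.nzind), 0)

    x = sparsevec([2, 5], [1.0, 2.0], 9)
    y = sparsevec([2, 9], [3.0, 4.0], 9)
    offsets = map(2:9) do row                # overall_first_row:overall_last_row
        ptrs = (has_row(x, row), has_row(y, row))
        key  = all(iszero, ptrs) ? typemax(Int) : row
        key => ptrs
    end
    sort!(offsets; by=first)  # rows with entries first; all-zero rows at the back
    # offsets now starts with 2 => (1, 1), 5 => (2, 0), 9 => (0, 2),
    # so counting keys != typemax gives total_nnz == 3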
@@ -331,8 +374,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
     return
 end
 
-# broadcast kernels that iterate the elements of sparse arrays
-function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
+function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
+                                           offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                           args...) where {Tv, Ti, N, F}
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > output.nnz && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row = @inbounds row_and_ptrs[1]
+    arg_ptrs = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    output_val = f(vals...)
+    @inbounds output.iPtr[row_ix] = row
+    @inbounds output.nzVal[row_ix] = output_val
+    return
+end
+
+function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
+                                           args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
+                                                                        CuSparseDeviceMatrixCSC{<:Any,Ti}}}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -345,7 +410,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
     # fetch the row offset, and write it to the output
     @inbounds begin
         output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
-        if leading_dim == leading_dim_size
+        if leading_dim == leading_dim_size
             output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
         end
     end
@@ -368,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
 
     return
 end
-function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
+function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
+                                                          CuSparseMatrixCSC{Tv, Ti}}}, f,
                                           output::CuDeviceArray, args...) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -392,6 +458,28 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 
     return
 end
+
+function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
+                                          output::CuDeviceArray{Tv},
+                                          offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                          args...) where {Tv, F, N, Ti}
+    # every thread processes an entire row
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > length(output) && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row = @inbounds row_and_ptrs[1]
+    arg_ptrs = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(length(args))) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    @inbounds output[row] = f(vals...)
+    return
+end
 ## COV_EXCL_STOP
 
 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
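Several decisions below hinge on `fpreszeros`, i.e. whether `f` maps all-zero inputs to zero; that flag is computed earlier in this function, outside this diff. A minimal sketch of such a check, with a hypothetical `fpreszeros_check` (illustration only):

    fpreszeros_check(f, args...) = iszero(f(map(a -> zero(eltype(a)), args)...))

    fpreszeros_check(*, [1.0], [2.0])    # true:  0.0 * 0.0 == 0.0, output can stay sparse
    fpreszeros_check(x -> x + 1, [1.0])  # false: zeros map to 1, output must densify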
@@ -405,12 +493,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} ||
-        error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices")
+    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} ||
+        error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: CuSparseMatrixCSR
         reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: CuSparseMatrixCSC
         reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+    elseif sparse_typ <: CuSparseVector
+        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
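End to end, this enables elementwise broadcast over `CuSparseVector`, with the output container chosen by zero preservation. A hypothetical session (assumes a CUDA-capable device; values illustrative):

    using CUDA, CUDA.CUSPARSE, SparseArrays

    x = CuSparseVector(sparsevec([2, 5, 9], [1.0, 2.0, 3.0], 16))
    x .* 2.0  # zero-preserving f: result stays a CuSparseVector
    x .+ 1.0  # non-zero-preserving f: result densifies to a CuArray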
@@ -433,23 +523,32 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = size(bc)
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)    # `size(bc, ::Int)` is missing
     function compute_launch_config(kernel)
         config = launch_configuration(kernel.fun)
         if sparse_typ <: CuSparseMatrixCSR
             threads = min(rows, config.threads)
-            blocks = max(cld(rows, threads), config.blocks)
+            blocks  = max(cld(rows, threads), config.blocks)
             threads = cld(rows, blocks)
         elseif sparse_typ <: CuSparseMatrixCSC
             threads = min(cols, config.threads)
-            blocks = max(cld(cols, threads), config.blocks)
+            blocks  = max(cld(cols, threads), config.blocks)
             threads = cld(cols, blocks)
+        elseif sparse_typ <: CuSparseVector
+            threads = 512
+            blocks  = max(cld(rows, threads), config.blocks)
         end
         (; threads, blocks)
     end
-
+    # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
+    # but the only rows present in any sparse vector input are between 2 and 128, no need to
+    # launch massive threads.
+    # TODO: use the difference here to set the thread count
+    overall_first_row = one(Ti)
+    overall_last_row = Ti(rows)
+    offsets = nothing
     # allocate the output container
-    if !fpreszeros
+    if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         output = CuArray{Tv}(undef, size(bc))
@@ -466,20 +565,20 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
         end
         broadcast!(bc.f, output, nonsparse_args...)
-    elseif length(sparse_args) == 1
+    elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # we only have a single sparse input, so we can reuse its structure for the output.
         # this avoids a kernel launch and costly synchronization.
         sparse_arg = bc.args[first(sparse_args)]
         if sparse_typ <: CuSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal = similar(sparse_arg.colVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
+            colVal  = similar(sparse_arg.colVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal = similar(sparse_arg.rowVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
+            rowVal  = similar(sparse_arg.rowVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
         end
         # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
         # while we do that in our kernel (for consistency with other code paths)
@@ -490,43 +589,79 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             CuArray{Ti}(undef, rows+1)
         elseif sparse_typ <: CuSparseMatrixCSC
             CuArray{Ti}(undef, cols+1)
+        elseif sparse_typ <: CuSparseVector
+            CUDA.@allowscalar begin
+                arg_first_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1]
+                    return one(Ti)
+                end
+                arg_last_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end]
+                    return Ti(rows)
+                end
+            end
+            overall_first_row = min(arg_first_rows...)
+            overall_last_row = max(arg_last_rows...)
+            CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
         end
         let
-            args = (sparse_typ, offsets, bc.args...)
+            args = if sparse_typ <: CuSparseVector
+                (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
+            else
+                (sparse_typ, offsets, bc.args...)
+            end
             kernel = @cuda launch=false compute_offsets_kernel(args...)
             threads, blocks = compute_launch_config(kernel)
             kernel(args...; threads, blocks)
         end
-
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
         # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
-        accumulate!(Base.add_sum, offsets, offsets)
-        total_nnz = @allowscalar last(offsets[end]) - 1
-
+        if !(sparse_typ <: CuSparseVector)
+            accumulate!(Base.add_sum, offsets, offsets)
+            total_nnz = @allowscalar last(offsets[end]) - 1
+        else
+            sort!(offsets; by=first)
+            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+        end
         output = if sparse_typ <: CuSparseMatrixCSR
             colVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             rowVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
+        elseif sparse_typ <: CuSparseVector && !fpreszeros
+            CuArray{Tv}(undef, size(bc))
+        elseif sparse_typ <: CuSparseVector && fpreszeros
+            iPtr = CUDA.zeros(Ti, total_nnz)
+            nzVal = CUDA.zeros(Tv, total_nnz)
+            CuSparseVector(iPtr, nzVal, rows)
+        end
+        if sparse_typ <: CuSparseVector && !fpreszeros
+            nonsparse_args = map(bc.args) do arg
+                # NOTE: this assumes the broadcast is flattened, but not yet preprocessed
+                if arg isa AbstractCuSparseArray
+                    zero(eltype(arg))
+                else
+                    arg
+                end
+            end
+            broadcast!(bc.f, output, nonsparse_args...)
+        end
     end
-
     # perform the actual broadcast
     if output isa AbstractCuSparseArray
-        args = (bc.f, output, offsets, bc.args...)
+        args   = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     else
-        args = (sparse_typ, bc.f, output, bc.args...)
+        args   = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+                 (sparse_typ, bc.f, output, bc.args...)
         kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     end
+    threads, blocks = compute_launch_config(kernel)
+    kernel(args...; threads, blocks)
 
     return output
 end
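For a sense of what `compute_launch_config` returns on the CSR path, here is the same arithmetic with made-up occupancy numbers (the real ones come from `launch_configuration`):

    rows    = 10_000
    config  = (threads = 1024, blocks = 16)           # illustrative occupancy result
    threads = min(rows, config.threads)               # 1024
    blocks  = max(cld(rows, threads), config.blocks)  # max(10, 16) = 16
    threads = cld(rows, blocks)                       # 625 threads per block, 16 blocks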