@@ -84,7 +84,8 @@ end
         end
     end
 end
-@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
+@inline function _capturescalars(arg)
+    # this definition is just an optimization (to bottom out the recursion slightly sooner)
     if scalararg(arg)
         return (), () -> (arg,) # add scalararg
     elseif scalarwrappedarg(arg)
@@ -103,7 +104,6 @@ end
 
 ## COV_EXCL_START
 ## iteration helpers
-
 """
     CSRIterator{Ti}(row, args...)
 
@@ -288,15 +288,20 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, I, ptr)
+@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
+                                      CuSparseDeviceMatrixCSC{Tv},
+                                      CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
     if ptr == 0
-        zero(eltype(arg))
+        return zero(Tv)
     else
-        @inbounds arg.nzVal[ptr]
+        return @inbounds arg.nzVal[ptr]::Tv
     end
 end
-_getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
+@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv}
+    return @inbounds arg[I]::Tv
+end
+@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
 ## sparse broadcast implementation
 
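Editor's note on this hunk: the `ptr` argument encodes sparsity — `0` means "no stored entry at this index", anything else is a position into `nzVal`. A minimal CPU-side sketch of that convention, using a hypothetical `FakeSparse` stand-in rather than the actual device array types:

```julia
# CPU sketch of the ptr convention used by the device-side _getindex:
# ptr == 0 reads a structural zero; otherwise ptr indexes into nzVal.
struct FakeSparse{Tv}
    nzVal::Vector{Tv}
end

getindex_sketch(arg::FakeSparse{Tv}, I, ptr) where {Tv} =
    ptr == 0 ? zero(Tv) : arg.nzVal[ptr]
getindex_sketch(arg, I, ptr) = arg[I]   # dense fallback indexes by row

v = FakeSparse([1.0, 2.0])
@assert getindex_sketch(v, 5, 0) == 0.0   # no stored entry at row 5
@assert getindex_sketch(v, 3, 2) == 2.0   # second stored value
```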
@@ -305,8 +310,46 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
+_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row
+_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row
+function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32
+    for row_ix in 1i32:length(A.iPtr)
+        arg_row = @inbounds A.iPtr[row_ix]
+        arg_row == row && return row_ix
+        arg_row > row && break
+    end
+    return 0i32
+end
+
+function _get_my_row(first_row)::Int32
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    return row_ix + first_row - 1i32
+end
+
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
+                                fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                args...) where {Ti, N}
+    row = _get_my_row(first_row)
+    row > last_row && return
+
+    # TODO load arg.iPtr slices into shared memory
+    row_is_nnz = 0i32
+    arg_row_is_nnz = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        _has_row(arg, offsets, row, fpreszeros)::Int32
+    end
+    row_is_nnz = 0i32
+    for i in 1:N
+        row_is_nnz |= @inbounds arg_row_is_nnz[i]
+    end
+    key = (row_is_nnz == 0i32) ? typemax(Ti) : row
+    @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz
+    return
+end
+
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
+function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
+                                offsets::AbstractVector{Ti},
                                 args...) where Ti
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
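To make the vector-path encoding concrete: each slot of `offsets` holds `row => (per-argument pointers)`, and rows that are structurally zero in every input get the sentinel key `typemax(Ti)`, so the later sort in `Broadcast.copy` pushes them past the nonzeros. A small CPU sketch of that layout, with illustrative values only:

```julia
# CPU sketch of the offsets layout produced by the CuSparseVector
# compute_offsets_kernel: `key => per-arg nzVal pointers`, with
# typemax(Ti) marking rows that are structurally zero in every input.
Ti = Int32
offsets = [Ti(5) => (Ti(2), Ti(1)),        # row 5: stored in both args
           typemax(Ti) => (Ti(0), Ti(0)),  # empty row, sorts to the end
           Ti(3) => (Ti(1), Ti(0))]        # row 3: stored in arg 1 only

sort!(offsets; by=first)                   # same ordering step as Broadcast.copy
total_nnz = count(x -> first(x) != typemax(Ti), offsets)
@assert first.(offsets[1:total_nnz]) == Ti[3, 5]
```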
@@ -331,8 +374,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
     return
 end
 
-# broadcast kernels that iterate the elements of sparse arrays
-function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
+function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
+                                           offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                           args...) where {Tv, Ti, N, F}
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > output.nnz && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row          = @inbounds row_and_ptrs[1]
+    arg_ptrs     = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    output_val = f(vals...)
+    @inbounds output.iPtr[row_ix]  = row
+    @inbounds output.nzVal[row_ix] = output_val
+    return
+end
+
+function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
+                                           args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
+                                                                        CuSparseDeviceMatrixCSC{<:Any,Ti}}}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
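How the new vector kernel consumes those sorted offsets: thread `row_ix` reads one `row => per-arg ptrs` pair and writes exactly one stored entry of the output. A CPU sketch with hypothetical `nzVal` contents:

```julia
# CPU sketch of what each thread of the vector sparse-to-sparse kernel
# does: slot row_ix yields `row => per-arg ptrs`; write one output entry.
Ti = Int32
f(a, b) = a * b
offsets = [Ti(2) => (Ti(1), Ti(1)), Ti(4) => (Ti(2), Ti(2))]
arg1_nz = [3.0, 5.0]                    # hypothetical nzVal of argument 1
arg2_nz = [10.0, 10.0]                  # hypothetical nzVal of argument 2
iPtr, nzVal = zeros(Ti, 2), zeros(2)
for row_ix in 1:2                       # one GPU thread per iteration
    row, ptrs = offsets[row_ix]
    vals = (arg1_nz[ptrs[1]], arg2_nz[ptrs[2]])  # ptr == 0 would read a zero
    iPtr[row_ix]  = row
    nzVal[row_ix] = f(vals...)
end
@assert iPtr == Ti[2, 4] && nzVal == [30.0, 50.0]
```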
@@ -345,7 +410,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
     # fetch the row offset, and write it to the output
     @inbounds begin
         output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
-        if leading_dim == leading_dim_size
+        if leading_dim == leading_dim_size
             output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
         end
     end
@@ -368,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
 
     return
 end
-function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
+function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
+                                                          CuSparseMatrixCSC{Tv, Ti}}}, f,
                                           output::CuDeviceArray, args...) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -392,6 +458,28 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 
     return
 end
+
+function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
+                                          output::CuDeviceArray{Tv},
+                                          offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                          args...) where {Tv, F, N, Ti}
+    # every thread processes an entire row
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > length(output) && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row          = @inbounds row_and_ptrs[1]
+    arg_ptrs     = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(length(args))) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    @inbounds output[row] = f(vals...)
+    return
+end
 ## COV_EXCL_STOP
 
 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
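Note that this kernel only writes candidate rows. When `f` does not preserve zeros, `Broadcast.copy` (in the hunks below) first fills the dense output by broadcasting `f` with the sparse arguments replaced by scalar zeros, and this kernel then overwrites the rows with stored entries. A CPU sketch of that two-step trick, with hypothetical stand-in values:

```julia
# CPU sketch: dense fallback for a non-zero-preserving f.
f(a) = a + 1.0                      # f(0) != 0, so the result is dense
stored = Dict(2 => 2.0)             # stand-in for a sparse vector's entries
out = fill(f(0.0), 3)               # step 1: background = f of the zeros
for (row, val) in stored            # step 2: overwrite stored rows
    out[row] = f(val)
end
@assert out == [1.0, 3.0, 1.0]
```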
@@ -405,12 +493,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} ||
-        error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices")
+    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} ||
+        error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: CuSparseMatrixCSR
         reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: CuSparseMatrixCSC
         reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+    elseif sparse_typ <: CuSparseVector
+        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
@@ -433,23 +523,32 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = size(bc)
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)  # `size(bc, ::Int)` is missing
     function compute_launch_config(kernel)
         config = launch_configuration(kernel.fun)
         if sparse_typ <: CuSparseMatrixCSR
             threads = min(rows, config.threads)
-            blocks = max(cld(rows, threads), config.blocks)
+            blocks  = max(cld(rows, threads), config.blocks)
             threads = cld(rows, blocks)
         elseif sparse_typ <: CuSparseMatrixCSC
             threads = min(cols, config.threads)
-            blocks = max(cld(cols, threads), config.blocks)
+            blocks  = max(cld(cols, threads), config.blocks)
             threads = cld(cols, blocks)
+        elseif sparse_typ <: CuSparseVector
+            threads = 512
+            blocks  = max(cld(rows, threads), config.blocks)
         end
         (; threads, blocks)
     end
-
+    # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
+    # but the only rows present in any sparse vector input are between 2 and 128, no need to
+    # launch massive threads.
+    # TODO: use the difference here to set the thread count
+    overall_first_row = one(Ti)
+    overall_last_row = Ti(rows)
+    offsets = nothing
     # allocate the output container
-    if !fpreszeros
+    if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         output = CuArray{Tv}(undef, size(bc))
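The launch heuristic above is easy to check on the CPU. A quick sketch with made-up occupancy numbers (the real `config` comes from `launch_configuration`):

```julia
# CPU sketch of compute_launch_config for the CSR path: parallelize
# across blocks first, then shrink threads to just cover the rows.
rows = 1000
config = (threads = 256, blocks = 4)              # hypothetical occupancy result
threads = min(rows, config.threads)               # 256
blocks  = max(cld(rows, threads), config.blocks)  # max(4, 4) -> 4
threads = cld(rows, blocks)                       # 250
@assert blocks * threads >= rows
```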
@@ -466,20 +565,20 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
         end
         broadcast!(bc.f, output, nonsparse_args...)
-    elseif length(sparse_args) == 1
+    elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # we only have a single sparse input, so we can reuse its structure for the output.
         # this avoids a kernel launch and costly synchronization.
         sparse_arg = bc.args[first(sparse_args)]
         if sparse_typ <: CuSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal = similar(sparse_arg.colVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
+            colVal  = similar(sparse_arg.colVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal = similar(sparse_arg.rowVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
+            rowVal  = similar(sparse_arg.rowVal)
+            nzVal   = similar(sparse_arg.nzVal, Tv)
+            output  = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
         end
         # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
         # while we do that in our kernel (for consistency with other code paths)
@@ -490,43 +589,79 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             CuArray{Ti}(undef, rows+1)
         elseif sparse_typ <: CuSparseMatrixCSC
             CuArray{Ti}(undef, cols+1)
+        elseif sparse_typ <: CuSparseVector
+            CUDA.@allowscalar begin
+                arg_first_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1]
+                    return one(Ti)
+                end
+                arg_last_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end]
+                    return Ti(rows)
+                end
+            end
+            overall_first_row = min(arg_first_rows...)
+            overall_last_row = max(arg_last_rows...)
+            CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
         end
         let
-            args = (sparse_typ, offsets, bc.args...)
+            args = if sparse_typ <: CuSparseVector
+                (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
+            else
+                (sparse_typ, offsets, bc.args...)
+            end
             kernel = @cuda launch=false compute_offsets_kernel(args...)
             threads, blocks = compute_launch_config(kernel)
             kernel(args...; threads, blocks)
         end
-
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
         # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
-        accumulate!(Base.add_sum, offsets, offsets)
-        total_nnz = @allowscalar last(offsets[end]) - 1
-
+        if !(sparse_typ <: CuSparseVector)
+            accumulate!(Base.add_sum, offsets, offsets)
+            total_nnz = @allowscalar last(offsets[end]) - 1
+        else
+            sort!(offsets; by=first)
+            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+        end
         output = if sparse_typ <: CuSparseMatrixCSR
             colVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             rowVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal  = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
+        elseif sparse_typ <: CuSparseVector && !fpreszeros
+            CuArray{Tv}(undef, size(bc))
+        elseif sparse_typ <: CuSparseVector && fpreszeros
+            iPtr  = CUDA.zeros(Ti, total_nnz)
+            nzVal = CUDA.zeros(Tv, total_nnz)
+            CuSparseVector(iPtr, nzVal, rows)
+        end
+        if sparse_typ <: CuSparseVector && !fpreszeros
+            nonsparse_args = map(bc.args) do arg
+                # NOTE: this assumes the broadcast is flattened, but not yet preprocessed
+                if arg isa AbstractCuSparseArray
+                    zero(eltype(arg))
+                else
+                    arg
+                end
+            end
+            broadcast!(bc.f, output, nonsparse_args...)
         end
     end
-
     # perform the actual broadcast
     if output isa AbstractCuSparseArray
-        args = (bc.f, output, offsets, bc.args...)
+        args   = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     else
-        args = (sparse_typ, bc.f, output, bc.args...)
+        args   = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+                                                (sparse_typ, bc.f, output, bc.args...)
         kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     end
+    threads, blocks = compute_launch_config(kernel)
+    kernel(args...; threads, blocks)
 
     return output
 end
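Taken together, these changes extend sparse broadcasting to `CuSparseVector`. A hypothetical usage sketch (not from the PR's tests; assumes a working CUDA setup), showing both output paths implemented above:

```julia
using CUDA, SparseArrays
using CUDA.CUSPARSE: CuSparseVector

x = CuSparseVector(sparsevec([2, 4], [1.0, 2.0], 6))
y = CuSparseVector(sparsevec([2, 5], [3.0, 4.0], 6))

z = x .* y      # zero-preserving f: output stays a CuSparseVector
w = x .+ 1.0    # f(0) != 0: falls back to a dense CuArray, per the logic above
```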