Commit 7a00a7e

kshyatt and maleadt authored
Initial implementation of broadcast for CuSparseVector (#2733)
Co-authored-by: Tim Besard <[email protected]>
1 parent bf1c589 commit 7a00a7e
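In rough terms, what this commit enables, as a minimal usage sketch (assuming a CUDA-capable device; variable names and values here are hypothetical, not from the diff):

    using CUDA, CUDA.CUSPARSE, SparseArrays

    x = CuSparseVector(sprand(Float32, 1_000, 0.05))  # sparse vector on the GPU
    y = 2f0 .* x   # zero-preserving function: result stays a CuSparseVector
    z = x .+ 1f0   # non-zero-preserving function: result is a dense CuArray

The sparse-vs-dense output split follows from the `fpreszeros` branches added in Broadcast.copy below.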

File tree: 3 files changed (+254, -45 lines)

lib/cusparse/broadcast.jl

Lines changed: 174 additions & 39 deletions
@@ -84,7 +84,8 @@ end
         end
     end
 end
-@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
+@inline function _capturescalars(arg)
+    # this definition is just an optimization (to bottom out the recursion slightly sooner)
     if scalararg(arg)
         return (), () -> (arg,) # add scalararg
     elseif scalarwrappedarg(arg)
@@ -103,7 +104,6 @@ end
 
 ## COV_EXCL_START
 ## iteration helpers
-
 """
     CSRIterator{Ti}(row, args...)
 
@@ -288,15 +288,20 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{CuSparseDeviceMatrixCSR,CuSparseDeviceMatrixCSC}, I, ptr)
+@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
+                                      CuSparseDeviceMatrixCSC{Tv},
+                                      CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
     if ptr == 0
-        zero(eltype(arg))
+        return zero(Tv)
     else
-        @inbounds arg.nzVal[ptr]
+        return @inbounds arg.nzVal[ptr]::Tv
     end
 end
-_getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
+@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv}
+    return @inbounds arg[I]::Tv
+end
+@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
 
 ## sparse broadcast implementation
 
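The role of the `ptr` argument above, shown as a CPU-only sketch (hypothetical helper, not part of the diff): a zero pointer encodes "no stored entry here", so the helper returns a structural zero instead of reading `nzVal`:

    # ptr == 0 encodes a structural zero; anything else indexes into nzVal
    getindex_sketch(nzVal, ptr) = ptr == 0 ? zero(eltype(nzVal)) : nzVal[ptr]

    nz = [1.0, 2.0, 3.0]
    getindex_sketch(nz, 2)  # 2.0 (stored entry)
    getindex_sketch(nz, 0)  # 0.0 (no entry at this position)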
@@ -305,8 +310,46 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
 iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
 
+_has_row(A, offsets, row::Int32, fpreszeros::Bool) = fpreszeros ? 0i32 : row
+_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row
+function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32
+    for row_ix in 1i32:length(A.iPtr)
+        arg_row = @inbounds A.iPtr[row_ix]
+        arg_row == row && return row_ix
+        arg_row > row && break
+    end
+    return 0i32
+end
+
+function _get_my_row(first_row)::Int32
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    return row_ix + first_row - 1i32
+end
+
+function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
+                                fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                args...) where {Ti, N}
+    row = _get_my_row(first_row)
+    row > last_row && return
+
+    # TODO load arg.iPtr slices into shared memory
+    row_is_nnz = 0i32
+    arg_row_is_nnz = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        _has_row(arg, offsets, row, fpreszeros)::Int32
+    end
+    row_is_nnz = 0i32
+    for i in 1:N
+        row_is_nnz |= @inbounds arg_row_is_nnz[i]
+    end
+    key = (row_is_nnz == 0i32) ? typemax(Ti) : row
+    @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz
+    return
+end
+
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
+function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
+                                offsets::AbstractVector{Ti},
                                 args...) where Ti
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
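The `offsets` element type `Pair{Ti, NTuple{N, Ti}}` packs, per candidate row, the row number and one pointer per broadcast argument. A CPU-only sketch of the encoding for all-sparse arguments with a zero-preserving `f` (hypothetical helper; the kernel above does this one row per thread):

    # For each row, record where every argument stores it (0 = not stored);
    # rows stored nowhere get a typemax key so a later sort moves them to the back.
    function offsets_sketch(rows, iptrs...)
        map(1:rows) do row
            ptrs = map(ip -> something(findfirst(==(row), ip), 0), iptrs)
            key = all(iszero, ptrs) ? typemax(Int) : row
            key => ptrs
        end
    end

    offsets_sketch(4, [1, 3], [3, 4])
    # [1 => (1, 0), typemax(Int) => (0, 0), 3 => (2, 1), 4 => (0, 2)]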
@@ -331,8 +374,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
     return
 end
 
-# broadcast kernels that iterate the elements of sparse arrays
-function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
+function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
+                                           offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                           args...) where {Tv, Ti, N, F}
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > output.nnz && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row = @inbounds row_and_ptrs[1]
+    arg_ptrs = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(N)) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    output_val = f(vals...)
+    @inbounds output.iPtr[row_ix] = row
+    @inbounds output.nzVal[row_ix] = output_val
+    return
+end
+
+function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
+                                           args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
+                                                                        CuSparseDeviceMatrixCSC{<:Any,Ti}}}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
     leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
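How the vector kernel consumes one `offsets` entry, as a CPU-only sketch (hypothetical helper): each output slot maps to a `row => ptrs` pair, and each pointer is turned into a value before applying `f`:

    # One output nonzero: look up each argument (0 means "use zero"), apply f
    function apply_row_sketch(f, row_and_ptrs, args_nzvals)
        row, ptrs = row_and_ptrs
        vals = map((nz, p) -> p == 0 ? zero(eltype(nz)) : nz[p], args_nzvals, ptrs)
        row => f(vals...)
    end

    apply_row_sketch(+, 3 => (2, 1), ([1.0, 2.0], [5.0, 6.0]))  # 3 => 7.0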
@@ -345,7 +410,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
     # fetch the row offset, and write it to the output
     @inbounds begin
         output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
-        if leading_dim == leading_dim_size
+        if leading_dim == leading_dim_size
             output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
         end
     end
@@ -368,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
 
     return
 end
-function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
+function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
+                                                          CuSparseMatrixCSC{Tv, Ti}}}, f,
                                           output::CuDeviceArray, args...) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -392,6 +458,28 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
 
     return
 end
+
+function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
+                                          output::CuDeviceArray{Tv},
+                                          offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+                                          args...) where {Tv, F, N, Ti}
+    # every thread processes an entire row
+    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    row_ix > length(output) && return
+    row_and_ptrs = @inbounds offsets[row_ix]
+    row = @inbounds row_and_ptrs[1]
+    arg_ptrs = @inbounds row_and_ptrs[2]
+    vals = ntuple(Val(length(args))) do i
+        arg = @inbounds args[i]
+        # ptr is 0 if the sparse vector doesn't have an element at this row
+        # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
+        # ptr is 0 if the arg is a scalar AND f preserves zeros
+        ptr = @inbounds arg_ptrs[i]
+        _getindex(arg, row, ptr)::Tv
+    end
+    @inbounds output[row] = f(vals...)
+    return
+end
 ## COV_EXCL_STOP
 
 function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
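For the dense-output path (`f` not zero-preserving), `Broadcast.copy` below first fills the whole output by broadcasting `f` with sparse arguments replaced by scalar zeros, then the kernel above overwrites rows that have stored entries. A CPU-only sketch with `f = x -> x + 1` (hypothetical values):

    stored = Dict(1 => 10.0, 3 => 30.0)  # stored entries of a length-4 sparse vector
    out = fill(0.0 + 1, 4)               # pass 1: f applied to structural zeros
    for (row, v) in stored
        out[row] = v + 1                 # pass 2: overwrite rows with stored entries
    end
    out                                  # [11.0, 1.0, 31.0, 1.0]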
@@ -405,12 +493,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} ||
-        error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices")
+    sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} ||
+        error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: CuSparseMatrixCSR
         reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: CuSparseMatrixCSC
         reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+    elseif sparse_typ <: CuSparseVector
+        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
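`fpreszeros`, used throughout the branches below, is computed in the output-type section the comment above introduces (not shown in this diff): it asks whether `f` maps all-zero inputs to zero, which decides between the sparse and dense output paths. A rough host-side approximation, not the actual implementation:

    fpreszeros_sketch(f, Ts...) = iszero(f(map(zero, Ts)...))

    fpreszeros_sketch(*, Float32, Float32)    # true  -> sparse output
    fpreszeros_sketch(x -> x + 1f0, Float32)  # false -> dense output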
@@ -433,23 +523,32 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = size(bc)
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)  # `size(bc, ::Int)` is missing
     function compute_launch_config(kernel)
         config = launch_configuration(kernel.fun)
         if sparse_typ <: CuSparseMatrixCSR
             threads = min(rows, config.threads)
-            blocks = max(cld(rows, threads), config.blocks)
+            blocks = max(cld(rows, threads), config.blocks)
             threads = cld(rows, blocks)
         elseif sparse_typ <: CuSparseMatrixCSC
             threads = min(cols, config.threads)
-            blocks = max(cld(cols, threads), config.blocks)
+            blocks = max(cld(cols, threads), config.blocks)
             threads = cld(cols, blocks)
+        elseif sparse_typ <: CuSparseVector
+            threads = 512
+            blocks = max(cld(rows, threads), config.blocks)
         end
         (; threads, blocks)
     end
-
+    # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
+    # but the only rows present in any sparse vector input are between 2 and 128, no need to
+    # launch massive threads.
+    # TODO: use the difference here to set the thread count
+    overall_first_row = one(Ti)
+    overall_last_row = Ti(rows)
+    offsets = nothing
     # allocate the output container
-    if !fpreszeros
+    if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         output = CuArray{Tv}(undef, size(bc))
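The vector branch of `compute_launch_config` above uses fixed 512-thread blocks and covers all rows with blocks (subject to the `config.blocks` lower bound). With hypothetical numbers:

    rows    = 100_000
    threads = 512
    blocks  = cld(rows, threads)  # 196 blocks of 512 threads each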
@@ -466,20 +565,20 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             end
         end
         broadcast!(bc.f, output, nonsparse_args...)
-    elseif length(sparse_args) == 1
+    elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
         # we only have a single sparse input, so we can reuse its structure for the output.
         # this avoids a kernel launch and costly synchronization.
         sparse_arg = bc.args[first(sparse_args)]
         if sparse_typ <: CuSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal = similar(sparse_arg.colVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
+            colVal = similar(sparse_arg.colVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal = similar(sparse_arg.rowVal)
-            nzVal = similar(sparse_arg.nzVal, Tv)
-            output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
+            rowVal = similar(sparse_arg.rowVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
         end
         # NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
         # while we do that in our kernel (for consistency with other code paths)
@@ -490,43 +589,79 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
             CuArray{Ti}(undef, rows+1)
         elseif sparse_typ <: CuSparseMatrixCSC
             CuArray{Ti}(undef, cols+1)
+        elseif sparse_typ <: CuSparseVector
+            CUDA.@allowscalar begin
+                arg_first_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1]
+                    return one(Ti)
+                end
+                arg_last_rows = ntuple(Val(length(bc.args))) do i
+                    bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end]
+                    return Ti(rows)
+                end
+            end
+            overall_first_row = min(arg_first_rows...)
+            overall_last_row = max(arg_last_rows...)
+            CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
         end
         let
-            args = (sparse_typ, offsets, bc.args...)
+            args = if sparse_typ <: CuSparseVector
+                (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
+            else
+                (sparse_typ, offsets, bc.args...)
+            end
             kernel = @cuda launch=false compute_offsets_kernel(args...)
             threads, blocks = compute_launch_config(kernel)
             kernel(args...; threads, blocks)
         end
-
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
         # cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
-        accumulate!(Base.add_sum, offsets, offsets)
-        total_nnz = @allowscalar last(offsets[end]) - 1
-
+        if !(sparse_typ <: CuSparseVector)
+            accumulate!(Base.add_sum, offsets, offsets)
+            total_nnz = @allowscalar last(offsets[end]) - 1
+        else
+            sort!(offsets; by=first)
+            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+        end
         output = if sparse_typ <: CuSparseMatrixCSR
             colVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc))
         elseif sparse_typ <: CuSparseMatrixCSC
             rowVal = CuArray{Ti}(undef, total_nnz)
-            nzVal = CuArray{Tv}(undef, total_nnz)
+            nzVal = CuArray{Tv}(undef, total_nnz)
             CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
+        elseif sparse_typ <: CuSparseVector && !fpreszeros
+            CuArray{Tv}(undef, size(bc))
+        elseif sparse_typ <: CuSparseVector && fpreszeros
+            iPtr = CUDA.zeros(Ti, total_nnz)
+            nzVal = CUDA.zeros(Tv, total_nnz)
+            CuSparseVector(iPtr, nzVal, rows)
+        end
+        if sparse_typ <: CuSparseVector && !fpreszeros
+            nonsparse_args = map(bc.args) do arg
+                # NOTE: this assumes the broadcast is flattened, but not yet preprocessed
+                if arg isa AbstractCuSparseArray
+                    zero(eltype(arg))
+                else
+                    arg
+                end
+            end
+            broadcast!(bc.f, output, nonsparse_args...)
         end
     end
-
     # perform the actual broadcast
     if output isa AbstractCuSparseArray
-        args = (bc.f, output, offsets, bc.args...)
+        args = (bc.f, output, offsets, bc.args...)
         kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     else
-        args = (sparse_typ, bc.f, output, bc.args...)
+        args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+                                              (sparse_typ, bc.f, output, bc.args...)
         kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
-        threads, blocks = compute_launch_config(kernel)
-        kernel(args...; threads, blocks)
     end
+    threads, blocks = compute_launch_config(kernel)
+    kernel(args...; threads, blocks)
 
     return output
 end
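A CPU-only sketch of the vector post-processing in the hunk above: sorting by key pushes the `typemax` placeholders (rows stored nowhere) to the back, and the number of real keys is the output's nnz (hypothetical data; `count` stands in for the `mapreduce` over the GPU array):

    offsets = [typemax(Int) => (0, 0), 3 => (2, 1), 1 => (1, 0)]
    sort!(offsets; by=first)  # [1 => (1, 0), 3 => (2, 1), typemax(Int) => (0, 0)]
    total_nnz = count(p -> first(p) != typemax(Int), offsets)  # 2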

lib/cusparse/device.jl

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti}
     nnz::Ti
 end
 
-Base.length(g::CuSparseDeviceVector) = prod(g.dims)
+Base.length(g::CuSparseDeviceVector) = g.len
 Base.size(g::CuSparseDeviceVector) = (g.len,)
 SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz
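The old definition indexed a `dims` field that `CuSparseDeviceVector` does not appear to carry (the struct stores `len`, as the `size` method above shows), so `length` on the device-side type would fail; it now returns `g.len` directly.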
