Skip to content

Commit f7deec6

Browse files
authored
Add SparseArrays functionality for CuSparseDeviceColumnView (#2904)
1 parent e17d626 commit f7deec6

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

lib/cusparse/device.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Base.length(g::CuSparseDeviceVector) = g.len
2323
Base.size(g::CuSparseDeviceVector) = (g.len,)
2424
SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz
2525

26-
struct CuSparseDeviceMatrixCSC{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
26+
struct CuSparseDeviceMatrixCSC{Tv,Ti,A} <: SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
2727
colPtr::CuDeviceVector{Ti, A}
2828
rowVal::CuDeviceVector{Ti, A}
2929
nzVal::CuDeviceVector{Tv, A}
@@ -38,6 +38,29 @@ SparseArrays.rowvals(g::CuSparseDeviceMatrixCSC) = g.rowVal
3838
SparseArrays.getcolptr(g::CuSparseDeviceMatrixCSC) = g.colPtr
3939
SparseArrays.getnzval(g::CuSparseDeviceMatrixCSC) = g.nzVal
4040
SparseArrays.nzrange(g::CuSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
41+
SparseArrays.nonzeros(g::CuSparseDeviceMatrixCSC) = g.nzVal
42+
43+
const CuSparseDeviceColumnView{Tv, Ti} = SubArray{Tv, 1, <:CuSparseDeviceMatrixCSC{Tv, Ti}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}}
44+
function SparseArrays.nonzeros(x::CuSparseDeviceColumnView)
45+
rowidx, colidx = parentindices(x)
46+
A = parent(x)
47+
@inbounds y = view(SparseArrays.nonzeros(A), SparseArrays.nzrange(A, colidx))
48+
return y
49+
end
50+
51+
function SparseArrays.nonzeroinds(x::CuSparseDeviceColumnView)
52+
rowidx, colidx = parentindices(x)
53+
A = parent(x)
54+
@inbounds y = view(SparseArrays.rowvals(A), SparseArrays.nzrange(A, colidx))
55+
return y
56+
end
57+
SparseArrays.rowvals(x::CuSparseDeviceColumnView) = SparseArrays.nonzeroinds(x)
58+
59+
function SparseArrays.nnz(x::CuSparseDeviceColumnView)
60+
rowidx, colidx = parentindices(x)
61+
A = parent(x)
62+
return length(SparseArrays.nzrange(A, colidx))
63+
end
4164

4265
struct CuSparseDeviceMatrixCSR{Tv,Ti,A} <: AbstractSparseMatrix{Tv,Ti}
4366
rowPtr::CuDeviceVector{Ti, A}

test/libraries/cusparse/device.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,82 @@ using CUDA.CUSPARSE: CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDevi
2727
cuA = CuSparseMatrixBSR(A, 2)
2828
@test cudaconvert(cuA) isa CuSparseDeviceMatrixBSR{Float64, Cint, AS.Global}
2929
end
30+
31+
@testset "device SparseArrays api" begin
32+
@testset "nnz per column" begin
33+
function nnz_per_column(A::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
34+
function nnz_per_column_kernel(out, A)
35+
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
36+
col = @view A[:, i]
37+
out[i] = SparseArrays.nnz(col)
38+
nothing
39+
end
40+
41+
out = CuVector{Ti}(undef, size(A, 2))
42+
@cuda threads=size(A, 2) nnz_per_column_kernel(out, A)
43+
out
44+
end
45+
46+
nnz_per_column(A::SparseMatrixCSC) = map(SparseArrays.nnz, eachcol(A))
47+
48+
A = sprand(10, 10, 0.5)
49+
cuA = CuSparseMatrixCSC(A)
50+
51+
@test nnz_per_column(A) == Vector(nnz_per_column(cuA))
52+
end
53+
54+
@testset "sum per column" begin
55+
function sum_per_column(A::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
56+
function sum_per_column_kernel(out, A)
57+
j = blockIdx().x
58+
col = @view A[:, j]
59+
60+
v = zero(Tv)
61+
i = threadIdx().x
62+
while i <= SparseArrays.nnz(col)
63+
v += nonzeros(col)[i]
64+
i += blockDim().x
65+
end
66+
v = CUDA.reduce_warp(+, v)
67+
68+
if threadIdx().x == 1
69+
out[j] = v
70+
end
71+
nothing
72+
end
73+
74+
out = CuVector{Tv}(undef, size(A, 2))
75+
@cuda threads=32 blocks=size(A, 2) sum_per_column_kernel(out, A)
76+
out
77+
end
78+
79+
sum_per_column(A::SparseMatrixCSC) = vec(sum(A; dims=1))
80+
81+
A = sprand(10, 10, 0.5)
82+
cuA = CuSparseMatrixCSC(A)
83+
84+
@test sum_per_column(A) Vector(sum_per_column(cuA))
85+
end
86+
87+
@testset "last nonzero per column" begin
88+
function last_nz_per_column(A::CuSparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
89+
function last_nz_per_column_kernel(out, A)
90+
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
91+
col = @view A[:, i]
92+
out[i] = last(SparseArrays.rowvals(col))
93+
nothing
94+
end
95+
96+
out = CuVector{Ti}(undef, size(A, 2))
97+
@cuda threads=size(A, 2) last_nz_per_column_kernel(out, A)
98+
out
99+
end
100+
101+
last_nz_per_column(A::SparseMatrixCSC) = map(last SparseArrays.rowvals, eachcol(A))
102+
103+
A = sprand(10, 10, 0.5)
104+
cuA = CuSparseMatrixCSC(A)
105+
106+
@test last_nz_per_column(A) == Vector(last_nz_per_column(cuA))
107+
end
108+
end

0 commit comments

Comments
 (0)