Skip to content

Commit 132deaf

Browse files
authored
[rocSPARSE] Interface additional sparse routines (#707)
* [rocSPARSE] Interface additional sparse routines * Interface sparsetodense and densetosparse * Add more constructors for ROCSparseMatrixDescriptor * Fix the errors
1 parent fadefca commit 132deaf

File tree

5 files changed

+262
-16
lines changed

5 files changed

+262
-16
lines changed

src/sparse/conversions.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,15 @@ for (elty, felty) in ((:Int32, :Float32), (:Int64, :Float64), (:Int128, :Complex
201201
end
202202
end
203203

204+
## ROCSparseVector to ROCVector
205+
ROCVector(x::ROCSparseVector{T}) where {T} = ROCVector{T}(x)
206+
207+
function ROCVector{T}(sv::ROCSparseVector{T}) where {T}
208+
n = length(sv)
209+
dv = AMDGPU.zeros(T, n)
210+
scatter!(dv, sv, 'O')
211+
end
212+
204213
## CSR to BSR and vice-versa
205214

206215
for (fname,elty) in ((:rocsparse_scsr2bsr, :Float32),
@@ -400,7 +409,7 @@ for (elty, welty) in ((:Float16, :Float32), (:ComplexF16, :ComplexF32))
400409
end
401410
end
402411

403-
function Base.copyto!(dest::Array{T, 2}, src::AbstractROCSparseMatrix{T}) where T
412+
function Base.copyto!(dest::Matrix{T}, src::AbstractROCSparseMatrix{T}) where T
404413
copyto!(dest, ROCMatrix{T}(src))
405414
end
406415

src/sparse/generic.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,128 @@
22

33
## API functions
44

5+
function sparsetodense(A::Union{ROCSparseMatrixCSC{T},ROCSparseMatrixCSR{T},ROCSparseMatrixCOO{T}}, index::SparseChar,
6+
algo::rocsparse_sparse_to_dense_alg=rocsparse_sparse_to_dense_alg_default) where {T}
7+
m,n = size(A)
8+
B = ROCMatrix{T}(undef, m, n)
9+
desc_sparse = ROCSparseMatrixDescriptor(A, index)
10+
desc_dense = ROCDenseMatrixDescriptor(B)
11+
12+
function bufferSize()
13+
out = Ref{Csize_t}()
14+
rocsparse_sparse_to_dense(handle(), desc_sparse, desc_dense, algo, out, C_NULL)
15+
return out[]
16+
end
17+
18+
buffer_size = Ref{Csize_t}()
19+
with_workspace(bufferSize) do buffer
20+
buffer_size[] = sizeof(buffer)
21+
rocsparse_sparse_to_dense(handle(), desc_sparse, desc_dense, algo, buffer_size, buffer)
22+
end
23+
return B
24+
end
25+
26+
function densetosparse(A::ROCMatrix{T}, fmt::Symbol, index::SparseChar,
27+
algo::rocsparse_dense_to_sparse_alg=rocsparse_dense_to_sparse_alg_default) where {T}
28+
m,n = size(A)
29+
local rowPtr, colPtr, desc_sparse, B
30+
if fmt == :coo
31+
desc_sparse = ROCSparseMatrixDescriptor(ROCSparseMatrixCOO, T, Cint, m, n, index)
32+
elseif fmt == :csr
33+
rowPtr = ROCVector{Cint}(undef, m+1)
34+
desc_sparse = ROCSparseMatrixDescriptor(ROCSparseMatrixCSR, rowPtr, T, Cint, m, n, index)
35+
elseif fmt == :csc
36+
colPtr = ROCVector{Cint}(undef, n+1)
37+
desc_sparse = ROCSparseMatrixDescriptor(ROCSparseMatrixCSC, colPtr, T, Cint, m, n, index)
38+
else
39+
error("Format :$fmt not available, use :csc, :csr or :coo.")
40+
end
41+
desc_dense = ROCDenseMatrixDescriptor(A)
42+
43+
function bufferSize()
44+
out = Ref{Csize_t}()
45+
rocsparse_dense_to_sparse(handle(), desc_dense, desc_sparse, algo, out, C_NULL)
46+
return out[]
47+
end
48+
49+
buffer_size = Ref{Csize_t}()
50+
with_workspace(bufferSize) do buffer
51+
buffer_size[] = sizeof(buffer)
52+
# Analysis
53+
rocsparse_dense_to_sparse(handle(), desc_dense, desc_sparse, algo, C_NULL, buffer)
54+
nnzB = Ref{Int64}()
55+
rocsparse_spmat_get_size(desc_sparse, Ref{Int64}(), Ref{Int64}(), nnzB)
56+
if fmt == :coo
57+
rowInd = ROCVector{Cint}(undef, nnzB[])
58+
colInd = ROCVector{Cint}(undef, nnzB[])
59+
nzVal = ROCVector{T}(undef, nnzB[])
60+
B = ROCSparseMatrixCOO{T, Cint}(rowInd, colInd, nzVal, (m,n))
61+
rocsparse_coo_set_pointers(desc_sparse, B.rowInd, B.colInd, B.nzVal)
62+
elseif fmt == :csr
63+
colVal = ROCVector{Cint}(undef, nnzB[])
64+
nzVal = ROCVector{T}(undef, nnzB[])
65+
B = ROCSparseMatrixCSR{T, Cint}(rowPtr, colVal, nzVal, (m,n))
66+
rocsparse_csr_set_pointers(desc_sparse, B.rowPtr, B.colVal, B.nzVal)
67+
elseif fmt == :csc
68+
rowVal = ROCVector{Cint}(undef, nnzB[])
69+
nzVal = ROCVector{T}(undef, nnzB[])
70+
B = ROCSparseMatrixCSC{T, Cint}(colPtr, rowVal, nzVal, (m,n))
71+
rocsparse_csc_set_pointers(desc_sparse, B.colPtr, B.rowVal, B.nzVal)
72+
else
73+
error("Format :$fmt not available, use :csc, :csr or :coo.")
74+
end
75+
rocsparse_dense_to_sparse(handle(), desc_dense, desc_sparse, algo, buffer_size, buffer)
76+
end
77+
return B
78+
end
79+
580
function gather!(X::ROCSparseVector, Y::ROCVector, index::SparseChar)
681
descX = ROCSparseVectorDescriptor(X, index)
782
descY = ROCDenseVectorDescriptor(Y)
883
rocsparse_gather(handle(), descY, descX)
984
X
1085
end
1186

87+
function scatter!(Y::ROCVector, X::ROCSparseVector, index::SparseChar)
88+
descX = ROCSparseVectorDescriptor(X, index)
89+
descY = ROCDenseVectorDescriptor(Y)
90+
rocsparse_scatter(handle(), descX, descY)
91+
return Y
92+
end
93+
94+
function axpby!(alpha::Number, X::ROCSparseVector{T}, beta::Number, Y::ROCVector{T}, index::SparseChar) where {T}
95+
descX = ROCSparseVectorDescriptor(X, index)
96+
descY = ROCDenseVectorDescriptor(Y)
97+
rocsparse_axpby(handle(), Ref{T}(alpha), descX, Ref{T}(beta), descY)
98+
return Y
99+
end
100+
101+
function rot!(X::ROCSparseVector{T}, Y::ROCVector{T}, c::Number, s::Number, index::SparseChar) where {T}
102+
descX = ROCSparseVectorDescriptor(X, index)
103+
descY = ROCDenseVectorDescriptor(Y)
104+
rocsparse_rot(handle(), Ref{T}(c), Ref{T}(s), descX, descY)
105+
return X, Y
106+
end
107+
108+
function vv!(transx::SparseChar, X::ROCSparseVector{T}, Y::DenseROCVector{T}, index::SparseChar) where {T}
109+
descX = ROCSparseVectorDescriptor(X, index)
110+
descY = ROCDenseVectorDescriptor(Y)
111+
result = Ref{T}()
112+
113+
function bufferSize()
114+
out = Ref{Csize_t}()
115+
rocsparse_spvv(handle(), transx, descX, descY, result, T, out, C_NULL)
116+
return out[]
117+
end
118+
119+
buffer_size = Ref{Csize_t}()
120+
with_workspace(bufferSize) do buffer
121+
buffer_size[] = sizeof(buffer)
122+
rocsparse_spvv(handle(), transx, descX, descY, result, T, buffer_size, buffer)
123+
end
124+
return result[]
125+
end
126+
12127
function mv!(
13128
transa::SparseChar, alpha::Number, A::Union{ROCSparseMatrixCSR{T}, ROCSparseMatrixCSC{T}, ROCSparseMatrixCOO{T}},
14129
X::DenseROCVector{T}, beta::Number, Y::DenseROCVector{T}, index::SparseChar,

src/sparse/helpers.jl

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,33 @@ Base.unsafe_convert(::Type{rocsparse_dnmat_descr}, desc::ROCDenseMatrixDescripto
104104
mutable struct ROCSparseMatrixDescriptor
105105
handle::rocsparse_spmat_descr
106106

107+
function ROCSparseMatrixDescriptor(A::ROCSparseMatrixCOO, IndexBase::Char; transposed::Bool=false)
108+
desc_ref = Ref{rocsparse_spmat_descr}()
109+
if transposed
110+
rocsparse_create_coo_descr(
111+
desc_ref, reverse(size(A))..., nnz(A),
112+
A.colInd, A.rowInd, nonzeros(A),
113+
eltype(A.colInd), IndexBase, eltype(nonzeros(A))
114+
)
115+
else
116+
rocsparse_create_coo_descr(
117+
desc_ref, size(A)..., nnz(A),
118+
A.rowInd, A.colInd, nonzeros(A),
119+
eltype(A.rowInd), IndexBase, eltype(nonzeros(A))
120+
)
121+
end
122+
obj = new(desc_ref[])
123+
return finalizer(rocsparse_destroy_spmat_descr, obj)
124+
end
125+
126+
function ROCSparseMatrixDescriptor(::Type{ROCSparseMatrixCOO}, Tv::DataType, Ti::DataType, m::Integer, n::Integer, IndexBase::Char)
127+
desc_ref = Ref{rocsparse_spmat_descr}()
128+
rocsparse_create_coo_descr(desc_ref, m, n, Ti(0), C_NULL, C_NULL, C_NULL, Ti, IndexBase, Tv)
129+
obj = new(desc_ref[])
130+
finalizer(rocsparse_destroy_spmat_descr, obj)
131+
return obj
132+
end
133+
107134
function ROCSparseMatrixDescriptor(A::ROCSparseMatrixCSR, IndexBase::Char; transposed::Bool=false)
108135
desc_ref = Ref{rocsparse_spmat_descr}()
109136
if transposed
@@ -121,6 +148,14 @@ mutable struct ROCSparseMatrixDescriptor
121148
return finalizer(rocsparse_destroy_spmat_descr, obj)
122149
end
123150

151+
function ROCSparseMatrixDescriptor(::Type{ROCSparseMatrixCSR}, rowPtr::ROCVector, Tv::DataType, Ti::DataType, m::Integer, n::Integer, IndexBase::Char)
152+
desc_ref = Ref{rocsparse_spmat_descr}()
153+
rocsparse_create_csr_descr(desc_ref, m, n, Ti(0), rowPtr, C_NULL, C_NULL, Ti, Ti, IndexBase, Tv)
154+
obj = new(desc_ref[])
155+
finalizer(rocsparse_destroy_spmat_descr, obj)
156+
return obj
157+
end
158+
124159
function ROCSparseMatrixDescriptor(A::ROCSparseMatrixCSC, IndexBase::Char; transposed::Bool=false)
125160
desc_ref = Ref{rocsparse_spmat_descr}()
126161
if transposed
@@ -138,23 +173,12 @@ mutable struct ROCSparseMatrixDescriptor
138173
return finalizer(rocsparse_destroy_spmat_descr, obj)
139174
end
140175

141-
function ROCSparseMatrixDescriptor(A::ROCSparseMatrixCOO, IndexBase::Char; transposed::Bool=false)
176+
function ROCSparseMatrixDescriptor(::Type{ROCSparseMatrixCSC}, colPtr::ROCVector, Tv::DataType, Ti::DataType, m::Integer, n::Integer, IndexBase::Char)
142177
desc_ref = Ref{rocsparse_spmat_descr}()
143-
if transposed
144-
rocsparse_create_coo_descr(
145-
desc_ref, reverse(size(A))..., nnz(A),
146-
A.colInd, A.rowInd, nonzeros(A),
147-
eltype(A.colInd), IndexBase, eltype(nonzeros(A))
148-
)
149-
else
150-
rocsparse_create_coo_descr(
151-
desc_ref, size(A)..., nnz(A),
152-
A.rowInd, A.colInd, nonzeros(A),
153-
eltype(A.rowInd), IndexBase, eltype(nonzeros(A))
154-
)
155-
end
178+
rocsparse_create_csc_descr(desc_ref, m, n, Ti(0), colPtr, C_NULL, C_NULL, Ti, Ti, IndexBase, Tv)
156179
obj = new(desc_ref[])
157-
return finalizer(rocsparse_destroy_spmat_descr, obj)
180+
finalizer(rocsparse_destroy_spmat_descr, obj)
181+
return obj
158182
end
159183
end
160184

src/sparse/interfaces.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ function mm_wrapper(
2222
mm!(transa, transb, alpha, A, B, beta, C, 'O')
2323
end
2424

25+
LinearAlgebra.dot(x::ROCSparseVector{T}, y::DenseROCVector{T}) where {T <: BlasReal} = vv!('N', x, y, 'O')
26+
LinearAlgebra.dot(x::DenseROCVector{T}, y::ROCSparseVector{T}) where {T <: BlasReal} = dot(y, x)
27+
28+
LinearAlgebra.dot(x::ROCSparseVector{T}, y::DenseROCVector{T}) where {T <: BlasComplex} = vv!('C', x, y, 'O')
29+
LinearAlgebra.dot(x::DenseROCVector{T}, y::ROCSparseVector{T}) where {T <: BlasComplex} = conj(dot(y,x))
30+
2531
# legacy methods with final MulAddMul argument
2632
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
2733
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)

test/rocsparse/generic.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,95 @@
1+
fmt = Dict(ROCSparseMatrixCSC => :csc,
2+
ROCSparseMatrixCSR => :csr,
3+
ROCSparseMatrixCOO => :coo)
4+
5+
for SparseMatrixType in [ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO]
6+
@testset "$SparseMatrixType -- densetosparse algo=$algo" for algo in [rocSPARSE.rocsparse_dense_to_sparse_alg_default]
7+
@testset "densetosparse $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
8+
A_sparse = sprand(T, 10, 20, 0.5)
9+
A_dense = Matrix{T}(A_sparse)
10+
dA_dense = ROCMatrix{T}(A_dense)
11+
dA_sparse = rocSPARSE.densetosparse(dA_dense, fmt[SparseMatrixType], 'O', algo)
12+
@test A_sparse collect(dA_sparse)
13+
end
14+
end
15+
@testset "$SparseMatrixType -- sparsetodense algo=$algo" for algo in [rocSPARSE.rocsparse_sparse_to_dense_alg_default]
16+
@testset "sparsetodense $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
17+
A_dense = rand(T, 10, 20)
18+
A_sparse = sparse(A_dense)
19+
dA_sparse = SparseMatrixType(A_sparse)
20+
dA_dense = rocSPARSE.sparsetodense(dA_sparse, 'O', algo)
21+
@test A_dense collect(dA_dense)
22+
end
23+
end
24+
end
25+
26+
@testset "gather! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
27+
X = sprand(T, 20, 0.5)
28+
dX = ROCSparseVector{T}(X)
29+
Y = rand(T, 20)
30+
dY = ROCVector{T}(Y)
31+
rocSPARSE.gather!(dX, dY, 'O')
32+
Z = copy(X)
33+
for i = 1:nnz(X)
34+
Z[X.nzind[i]] = Y[X.nzind[i]]
35+
end
36+
@test Z sparse(collect(dX))
37+
end
38+
39+
@testset "scatter! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
40+
X = sprand(T, 20, 0.5)
41+
dX = ROCSparseVector{T}(X)
42+
Y = rand(T, 20)
43+
dY = ROCVector{T}(Y)
44+
rocSPARSE.scatter!(dY, dX, 'O')
45+
Z = copy(Y)
46+
for i = 1:nnz(X)
47+
Z[X.nzind[i]] = X.nzval[i]
48+
end
49+
@test Z collect(dY)
50+
end
51+
52+
@testset "axpby! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
53+
X = sprand(T, 20, 0.5)
54+
dX = ROCSparseVector{T}(X)
55+
Y = rand(T, 20)
56+
dY = ROCVector{T}(Y)
57+
alpha = rand(T)
58+
beta = rand(T)
59+
rocSPARSE.axpby!(alpha, dX, beta, dY, 'O')
60+
@test alpha * X + beta * Y collect(dY)
61+
end
62+
63+
@testset "rot! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
64+
X = sprand(T, 20, 0.5)
65+
dX = ROCSparseVector{T}(X)
66+
Y = rand(T, 20)
67+
dY = ROCVector{T}(Y)
68+
c = rand(T)
69+
s = rand(T)
70+
rocSPARSE.rot!(dX, dY, c, s, 'O')
71+
W = copy(X)
72+
Z = copy(Y)
73+
for i = 1:nnz(X)
74+
W[X.nzind[i]] = c * X.nzval[i] + s * Y[X.nzind[i]]
75+
Z[X.nzind[i]] = -s * X.nzval[i] + c * Y[X.nzind[i]]
76+
end
77+
@test W collect(dX)
78+
@test Z collect(dY)
79+
end
80+
81+
@testset "vv! $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
82+
for (transx, opx) in [('N', identity), ('C', conj)]
83+
T <: Real && transx == 'C' && continue
84+
X = sprand(T, 20, 0.5)
85+
dX = ROCSparseVector{T}(X)
86+
Y = rand(T, 20)
87+
dY = ROCVector{T}(Y)
88+
result = rocSPARSE.vv!(transx, dX, dY, 'O')
89+
@test sum(opx(X[i]) * Y[i] for i=1:20) result
90+
end
91+
end
92+
193
@testset "generic mv!" for T in (Float32, Float64, ComplexF32, ComplexF64)
294
A = sprand(T, 10, 10, 0.1)
395
x = rand(T, 10)

0 commit comments

Comments
 (0)