Skip to content

Commit 20012bb

Browse files
authored
Support more conversions (#816)
1 parent 979ae13 commit 20012bb

File tree

2 files changed

+64
-15
lines changed

2 files changed

+64
-15
lines changed

src/sparse/array.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ const ROCSparseMatrix{Tv, Ti} = Union{
154154

155155
const ROCSparseVecOrMat = Union{ROCSparseVector, ROCSparseMatrix}
156156

157-
# NOTE: we use Cint as default Ti on CUDA instead of Int to provide
158-
# maximum compatiblity to old CUSPARSE APIs
157+
# NOTE: we use Cint as default Ti on ROCm instead of Int to provide
158+
# maximum compatiblity to old ROCSparse APIs
159159
# The same pattern was followed for AMDGPU as well
160160
function ROCSparseVector{Tv}(iPtr::ROCVector{<:Integer}, nzVal::ROCVector, len::Integer) where Tv
161161
ROCSparseVector{Tv, Cint}(convert(ROCVector{Cint}, iPtr), nzVal, len)
@@ -284,6 +284,7 @@ SparseArrays.nnz(g::AbstractROCSparseArray) = g.nnz
284284
SparseArrays.nonzeros(g::AbstractROCSparseArray) = g.nzVal
285285

286286
SparseArrays.nonzeroinds(g::AbstractROCSparseVector) = g.iPtr
287+
SparseArrays.rowvals(g::AbstractROCSparseVector) = nonzeroinds(g)
287288

288289
SparseArrays.rowvals(g::ROCSparseMatrixCSC) = g.rowVal
289290
SparseArrays.getcolptr(g::ROCSparseMatrixCSC) = g.colPtr
@@ -422,14 +423,8 @@ ROCSparseMatrixCSC(x::Transpose{T}) where {T} = ROCSparseMatrixCSC{T}(x)
422423
ROCSparseMatrixCSC(x::Adjoint{T}) where {T} = ROCSparseMatrixCSC{T}(x)
423424

424425
# gpu to cpu
425-
function SparseVector(x::ROCSparseVector)
426-
SparseVector(length(x), Array(nonzeroinds(x)), Array(nonzeros(x)))
427-
end
428-
429-
function SparseMatrixCSC(x::ROCSparseMatrixCSC)
430-
SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(rowvals(x)), Array(nonzeros(x)))
431-
end
432-
426+
SparseVector(x::ROCSparseVector) = SparseVector(length(x), Array(nonzeroinds(x)), Array(nonzeros(x)))
427+
SparseMatrixCSC(x::ROCSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(rowvals(x)), Array(nonzeros(x)))
433428
SparseMatrixCSC(x::ROCSparseMatrixCSR) = SparseMatrixCSC(ROCSparseMatrixCSC(x)) # no direct conversion
434429
SparseMatrixCSC(x::ROCSparseMatrixBSR) = SparseMatrixCSC(ROCSparseMatrixCSR(x)) # no direct conversion
435430
SparseMatrixCSC(x::ROCSparseMatrixCOO) = SparseMatrixCSC(ROCSparseMatrixCSR(x)) # no direct conversion
@@ -519,7 +514,7 @@ Base.copy(Mat::ROCSparseMatrixCOO) = copyto!(similar(Mat), Mat)
519514

520515
# input/output
521516

522-
for (gpu, cpu) in [ROCSparseVector => SparseVector]
517+
for (gpu, cpu) in [:ROCSparseVector => :SparseVector]
523518
@eval function Base.show(io::IO, ::MIME"text/plain", x::$gpu)
524519
xnnz = length(nonzeros(x))
525520
print(io, length(x), "-element ", typeof(x), " with ", xnnz,
@@ -531,10 +526,10 @@ for (gpu, cpu) in [ROCSparseVector => SparseVector]
531526
end
532527
end
533528

534-
for (gpu, cpu) in [ROCSparseMatrixCSC => SparseMatrixCSC,
535-
ROCSparseMatrixCSR => SparseMatrixCSC,
536-
ROCSparseMatrixBSR => SparseMatrixCSC,
537-
ROCSparseMatrixCOO => SparseMatrixCSC]
529+
for (gpu, cpu) in [:ROCSparseMatrixCSC => :SparseMatrixCSC,
530+
:ROCSparseMatrixCSR => :SparseMatrixCSC,
531+
:ROCSparseMatrixBSR => :SparseMatrixCSC,
532+
:ROCSparseMatrixCOO => :SparseMatrixCSC]
538533
@eval Base.show(io::IOContext, x::$gpu) = show(io, $cpu(x))
539534

540535
@eval function Base.show(io::IO, mime::MIME"text/plain", S::$gpu)

src/sparse/conversions.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ function SparseArrays.sparse(
4646
end
4747
end
4848

49+
for SparseMatrixType in (:ROCSparseMatrixCSC, :ROCSparseMatrixCSR, :ROCSparseMatrixCOO)
50+
@eval SparseArrays.sparse(A::$SparseMatrixType) = A
51+
end
52+
4953
function sort_rows(coo::ROCSparseMatrixCOO{Tv,Ti}) where {Tv <: BlasFloat, Ti}
5054
m,n = size(coo)
5155
perm = ROCArray{Ti}(undef, nnz(coo))
@@ -487,3 +491,53 @@ function ROCSparseMatrixBSR(A::ROCMatrix; ind::SparseChar = 'O')
487491
m, n = size(A) # TODO: always let the user choose, or provide defaults for other methods too
488492
ROCSparseMatrixBSR(ROCSparseMatrixCSR(A; ind), gcd(m,n))
489493
end
494+
495+
496+
function AMDGPU.ROCMatrix{T}(coo::ROCSparseMatrixCOO{T}; index::SparseChar='O') where {T}
497+
sparsetodense(coo, index)
498+
end
499+
500+
function ROCSparseMatrixCOO(A::ROCMatrix{T}; index::SparseChar='O') where {T}
501+
densetosparse(A, :coo, index)
502+
end
503+
504+
## ROCSparseVector to ROCSparseMatrices and vice-versa
505+
function ROCSparseVector(A::ROCSparseMatrixCSC{T}) where T
506+
m, n = size(A)
507+
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
508+
ROCSparseVector{T}(A.rowVal, A.nzVal, m)
509+
end
510+
511+
# no direct conversion
512+
function ROCSparseVector(A::ROCSparseMatrixCSR{T}) where T
513+
m, n = size(A)
514+
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
515+
B = ROCSparseMatrixCSC{T}(A)
516+
ROCSparseVector(B)
517+
end
518+
519+
function ROCSparseVector(A::ROCSparseMatrixCOO{T}) where T
520+
m, n = size(A)
521+
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
522+
ROCSparseVector{T}(A.rowInd, A.nzVal, m)
523+
end
524+
525+
function ROCSparseMatrixCSC(x::ROCSparseVector{T}) where T
526+
n = length(x)
527+
colPtr = CuVector{Int32}([1; nnz(x)+1])
528+
ROCSparseMatrixCSC{T}(colPtr, x.iPtr, nonzeros(x), (n,1))
529+
end
530+
531+
# no direct conversion
532+
function ROCSparseMatrixCSR(x::ROCSparseVector{T}) where T
533+
A = ROCSparseMatrixCSC(x)
534+
ROCSparseMatrixCSR{T}(A)
535+
end
536+
537+
function ROCSparseMatrixCOO(x::ROCSparseVector{T}) where T
538+
n = length(x)
539+
nnzx = nnz(x)
540+
colInd = CuVector{Int32}(undef, nnzx)
541+
fill!(colInd, one(Int32))
542+
ROCSparseMatrixCOO{T}(x.iPtr, colInd, nonzeros(x), (n,1), nnzx)
543+
end

0 commit comments

Comments
 (0)