@@ -16,13 +16,36 @@ SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
1616SparseArrays. nonzeros(g:: T ) where {T<: AbstractGPUSparseArray } = g. nzVal
1717
1818SparseArrays. nonzeroinds(g:: T ) where {T<: AbstractGPUSparseVector } = g. iPtr
19- SparseArrays. rowvals(g:: T ) where {T<: AbstractGPUSparseVector } = nonzeroinds(g)
19+ SparseArrays. rowvals(g:: T ) where {T<: AbstractGPUSparseVector } = SparseArrays . nonzeroinds(g)
2020
2121SparseArrays. rowvals(g:: AbstractGPUSparseMatrixCSC ) = g. rowVal
2222SparseArrays. getcolptr(S:: AbstractGPUSparseMatrixCSC ) = S. colPtr
2323
2424Base. convert(T:: Type{<:AbstractGPUSparseArray} , m:: AbstractArray ) = m isa T ? m : T(m)
2525
26+ # collect to Array
27+ Base. collect(x:: AbstractGPUSparseVector ) = collect(SparseVector(x))
28+ Base. collect(x:: AbstractGPUSparseMatrixCSC ) = collect(SparseMatrixCSC(x))
29+ Base. collect(x:: AbstractGPUSparseMatrixCSR ) = collect(SparseMatrixCSC(x))
30+ Base. collect(x:: AbstractGPUSparseMatrixBSR ) = collect(SparseMatrixCSC(x))
31+ Base. collect(x:: AbstractGPUSparseMatrixCOO ) = collect(SparseMatrixCSC(x))
32+
33+ Base. Array(x:: AbstractGPUSparseVector ) = collect(SparseVector(x))
34+ Base. Array(x:: AbstractGPUSparseMatrixCSC ) = collect(SparseMatrixCSC(x))
35+ Base. Array(x:: AbstractGPUSparseMatrixCSR ) = collect(SparseMatrixCSC(x))
36+ Base. Array(x:: AbstractGPUSparseMatrixBSR ) = collect(SparseMatrixCSC(x))
37+ Base. Array(x:: AbstractGPUSparseMatrixCOO ) = collect(SparseMatrixCSC(x))
38+
39+ SparseArrays. SparseVector(x:: AbstractGPUSparseVector ) = SparseVector(length(x), Array(SparseArrays. nonzeroinds(x)), Array(SparseArrays. nonzeros(x)))
40+ SparseArrays. SparseMatrixCSC(x:: AbstractGPUSparseMatrixCSC ) = SparseMatrixCSC(size(x). .. , Array(SparseArrays. getcolptr(x)), Array(SparseArrays. rowvals(x)), Array(SparseArrays. nonzeros(x)))
41+
42+ # similar
43+ Base. similar(Vec:: V ) where {V<: AbstractGPUSparseVector } = V(copy(SparseArrays. nonzeroinds(Vec)), similar(SparseArrays. nonzeros(Vec)), length(Vec))
44+ Base. similar(Mat:: M ) where {M<: AbstractGPUSparseMatrixCSC } = M(copy(SparseArrays. getcolptr(Mat)), copy(SparseArrays. rowvals(Mat)), similar(SparseArrays. nonzeros(Mat)), size(Mat))
45+
46+ Base. similar(Vec:: V , T:: Type ) where {Tv, Ti, V<: AbstractGPUSparseVector{Tv, Ti} } = sparse_array_type(V){T, Ti}(copy(SparseArrays. nonzeroinds(Vec)), similar(SparseArrays. nonzeros(Vec), T), length(Vec))
47+ Base. similar(Mat:: M , T:: Type ) where {M<: AbstractGPUSparseMatrixCSC } = sparse_array_type(M)(copy(SparseArrays. getcolptr(Mat)), copy(SparseArrays. rowvals(Mat)), similar(SparseArrays. nonzeros(Mat), T), size(Mat))
48+
2649dense_array_type(sa:: SparseVector ) = SparseVector
2750dense_array_type(:: Type{SparseVector} ) = SparseVector
2851sparse_array_type(sa:: SparseVector ) = SparseVector
@@ -207,6 +230,52 @@ Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:s
207230Base. getindex(A:: AbstractGPUSparseMatrix , :: Colon , i) = getindex(A, 1 : size(A, 1 ), i)
208231Base. getindex(A:: AbstractGPUSparseMatrix , I:: Tuple{Integer,Integer} ) = getindex(A, I[1 ], I[2 ])
209232
233+ function Base. getindex(A:: AbstractGPUSparseVector{Tv, Ti} , i:: Integer ) where {Tv, Ti}
234+ @boundscheck checkbounds(A, i)
235+ ii = searchsortedfirst(SparseArrays. nonzeroinds(A), convert(Ti, i))
236+ (ii > SparseArrays. nnz(A) || SparseArrays. nonzeroinds(A)[ii] != i) && return zero(Tv)
237+ SparseArrays. nonzeros(A)[ii]
238+ end
239+
240+ function Base. getindex(A:: AbstractGPUSparseMatrixCSC{T} , i0:: Integer , i1:: Integer ) where T
241+ @boundscheck checkbounds(A, i0, i1)
242+ r1 = Int(SparseArrays. getcolptr(A)[i1])
243+ r2 = Int(SparseArrays. getcolptr(A)[i1+ 1 ]- 1 )
244+ (r1 > r2) && return zero(T)
245+ r1 = searchsortedfirst(SparseArrays. rowvals(A), i0, r1, r2, Base. Order. Forward)
246+ (r1 > r2 || SparseArrays. rowvals(A)[r1] != i0) && return zero(T)
247+ SparseArrays. nonzeros(A)[r1]
248+ end
249+
250+ # # copying between sparse GPU arrays
251+ Base. copy(Vec:: AbstractGPUSparseVector ) = copyto!(similar(Vec), Vec)
252+
253+ function Base. copyto!(dst:: AbstractGPUSparseVector , src:: AbstractGPUSparseVector )
254+ if length(dst) != length(src)
255+ throw(ArgumentError(" Inconsistent Sparse Vector size" ))
256+ end
257+ resize!(SparseArrays. nonzeroinds(dst), length(SparseArrays. nonzeroinds(src)))
258+ resize!(SparseArrays. nonzeros(dst), length(SparseArrays. nonzeros(src)))
259+ copyto!(SparseArrays. nonzeroinds(dst), SparseArrays. nonzeroinds(src))
260+ copyto!(SparseArrays. nonzeros(dst), SparseArrays. nonzeros(src))
261+ dst. nnz = src. nnz
262+ dst
263+ end
264+
265+ function Base. copyto!(dst:: AbstractGPUSparseMatrixCSC , src:: AbstractGPUSparseMatrixCSC )
266+ if size(dst) != size(src)
267+ throw(ArgumentError(" Inconsistent Sparse Matrix size" ))
268+ end
269+ resize!(SparseArrays. getcolptr(dst), length(SparseArrays. getcolptr(src)))
270+ resize!(SparseArrays. rowvals(dst), length(SparseArrays. rowvals(src)))
271+ resize!(SparseArrays. nonzeros(dst), length(SparseArrays. nonzeros(src)))
272+ copyto!(SparseArrays. getcolptr(dst), SparseArrays. getcolptr(src))
273+ copyto!(SparseArrays. rowvals(dst), SparseArrays. rowvals(src))
274+ copyto!(SparseArrays. nonzeros(dst), SparseArrays. nonzeros(src))
275+ dst. nnz = src. nnz
276+ dst
277+ end
278+
210279# ## BROADCAST
211280
212281# broadcast container type promotion for combinations of sparse arrays and other types
@@ -749,12 +818,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
749818 offsets = rowPtr = sparse_arg. rowPtr
750819 colVal = similar(sparse_arg. colVal)
751820 nzVal = similar(sparse_arg. nzVal, Tv)
752- output = _sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size(bc))
821+ output = sparse_array_type (sparse_typ)(rowPtr, colVal, nzVal, size(bc))
753822 elseif sparse_typ <: AbstractGPUSparseMatrixCSC
754823 offsets = colPtr = sparse_arg. colPtr
755824 rowVal = similar(sparse_arg. rowVal)
756825 nzVal = similar(sparse_arg. nzVal, Tv)
757- output = _sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size(bc))
826+ output = sparse_array_type (sparse_typ)(colPtr, rowVal, nzVal, size(bc))
758827 end
759828 else
760829 # determine the number of non-zero elements per row so that we can create an
@@ -803,15 +872,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
803872 output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
804873 ixVal = similar(offsets, Ti, total_nnz)
805874 nzVal = similar(offsets, Tv, total_nnz)
806- output_sparse_typ = _sparse_array_type (sparse_typ)
875+ output_sparse_typ = sparse_array_type (sparse_typ)
807876 output_sparse_typ(offsets, ixVal, nzVal, size(bc))
808877 elseif sparse_typ <: AbstractGPUSparseVector && ! fpreszeros
809878 val_array = bc. args[first(sparse_args)]. nzVal
810879 similar(val_array, Tv, size(bc))
811880 elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
812881 iPtr = similar(offsets, Ti, total_nnz)
813882 nzVal = similar(offsets, Tv, total_nnz)
814- _sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
883+ sparse_array_type (sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
815884 end
816885 if sparse_typ <: AbstractGPUSparseVector && ! fpreszeros
817886 nonsparse_args = map(bc. args) do arg
@@ -932,9 +1001,9 @@ function Base.mapreduce(f, op, A::AbstractGPUSparseMatrix; dims=:, init=nothing)
9321001 in(dims, [Colon(), 1 , 2 ]) || error(" only dims=:, dims=1 or dims=2 is supported" )
9331002
9341003 if A isa AbstractGPUSparseMatrixCSR && dims == 1
935- A = _csc_type (A)(A)
1004+ A = csc_type (A)(A)
9361005 elseif A isa AbstractGPUSparseMatrixCSC && dims == 2
937- A = _csr_type (A)(A)
1006+ A = csr_type (A)(A)
9381007 end
9391008 m, n = size(A)
9401009 val_array = nonzeros(A)
0 commit comments