@@ -586,10 +586,10 @@ end
586
586
# # CSR to COO and vice-versa
587
587
588
588
function CuSparseMatrixCSR {Tv} (coo:: CuSparseMatrixCOO{Tv} ; index:: SparseChar = ' O' ) where {Tv}
589
- m,n = size (coo)
590
- nnz (coo) == 0 && return CuSparseMatrixCSR {Tv} (CUDA. ones (Cint, m+ 1 ), coo. colInd, nonzeros (coo), size (coo))
589
+ m, n = size (coo)
590
+ csrRowPtr = (index == ' O' ) ? CUDA. ones (Cint, m + 1 ) : CUDA. zeros (Cint, m + 1 )
591
+ nnz (coo) == 0 && return CuSparseMatrixCSR {Tv} (csrRowPtr, coo. colInd, nonzeros (coo), size (coo))
591
592
coo = sort_coo (coo, ' R' )
592
- csrRowPtr = CuVector {Cint} (undef, m+ 1 )
593
593
cusparseXcoo2csr (handle (), coo. rowInd, nnz (coo), m, csrRowPtr, index)
594
594
CuSparseMatrixCSR {Tv} (csrRowPtr, coo. colInd, nonzeros (coo), size (coo))
595
595
end
@@ -605,10 +605,10 @@ end
605
605
# ## CSC to COO and viceversa
606
606
607
607
function CuSparseMatrixCSC {Tv} (coo:: CuSparseMatrixCOO{Tv} ; index:: SparseChar = ' O' ) where {Tv}
608
- m,n = size (coo)
609
- nnz (coo) == 0 && return CuSparseMatrixCSC {Tv} (CUDA. ones (Cint, n+ 1 ), coo. rowInd, nonzeros (coo), size (coo))
608
+ m, n = size (coo)
609
+ cscColPtr = (index == ' O' ) ? CUDA. ones (Cint, n + 1 ) : CUDA. zeros (Cint, n + 1 )
610
+ nnz (coo) == 0 && return CuSparseMatrixCSC {Tv} (cscColPtr, coo. rowInd, nonzeros (coo), size (coo))
610
611
coo = sort_coo (coo, ' C' )
611
- cscColPtr = CuVector {Cint} (undef, n+ 1 )
612
612
cusparseXcoo2csr (handle (), coo. colInd, nnz (coo), n, cscColPtr, index)
613
613
CuSparseMatrixCSC {Tv} (cscColPtr, coo. rowInd, nonzeros (coo), size (coo))
614
614
end
0 commit comments