Skip to content

Commit 059eeef

Browse files
committed
Respect zero/one based indexing in COO->CSC/CSR conversions
1 parent 822437c commit 059eeef

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

lib/cusparse/conversions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,10 +586,10 @@ end
586586
## CSR to COO and vice-versa
587587

588588
function CuSparseMatrixCSR{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv}
589-
m,n = size(coo)
590-
nnz(coo) == 0 && return CuSparseMatrixCSR{Tv}(CUDA.ones(Cint, m+1), coo.colInd, nonzeros(coo), size(coo))
589+
m, n = size(coo)
590+
csrRowPtr = (index == 'O') ? CUDA.ones(Cint, m + 1) : CUDA.zeros(Cint, m + 1)
591+
nnz(coo) == 0 && return CuSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo))
591592
coo = sort_coo(coo, 'R')
592-
csrRowPtr = CuVector{Cint}(undef, m+1)
593593
cusparseXcoo2csr(handle(), coo.rowInd, nnz(coo), m, csrRowPtr, index)
594594
CuSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo))
595595
end
@@ -605,10 +605,10 @@ end
605605
### CSC to COO and viceversa
606606

607607
function CuSparseMatrixCSC{Tv}(coo::CuSparseMatrixCOO{Tv}; index::SparseChar='O') where {Tv}
608-
m,n = size(coo)
609-
nnz(coo) == 0 && return CuSparseMatrixCSC{Tv}(CUDA.ones(Cint, n+1), coo.rowInd, nonzeros(coo), size(coo))
608+
m, n = size(coo)
609+
cscColPtr = (index == 'O') ? CUDA.ones(Cint, n + 1) : CUDA.zeros(Cint, n + 1)
610+
nnz(coo) == 0 && return CuSparseMatrixCSC{Tv}(cscColPtr, coo.rowInd, nonzeros(coo), size(coo))
610611
coo = sort_coo(coo, 'C')
611-
cscColPtr = CuVector{Cint}(undef, n+1)
612612
cusparseXcoo2csr(handle(), coo.colInd, nnz(coo), n, cscColPtr, index)
613613
CuSparseMatrixCSC{Tv}(cscColPtr, coo.rowInd, nonzeros(coo), size(coo))
614614
end

0 commit comments

Comments
 (0)