@@ -332,7 +332,7 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
332
332
@eval begin
333
333
function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
334
334
m,n = size (csr)
335
- colPtr = CUDA. zeros (Cint, n+ 1 )
335
+ colPtr = (index == ' O ' ) ? CUDA . ones (Cint, n + 1 ) : CUDA. zeros (Cint, n+ 1 )
336
336
rowVal = CUDA. zeros (Cint, nnz (csr))
337
337
nzVal = CUDA. zeros ($ elty, nnz (csr))
338
338
function bufferSize ()
@@ -352,7 +352,7 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
352
352
353
353
function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
354
354
m,n = size (csc)
355
- rowPtr = CUDA. zeros (Cint,m+ 1 )
355
+ rowPtr = (index == ' O ' ) ? CUDA. ones (Cint, m + 1 ) : CUDA . zeros (Cint, m+ 1 )
356
356
colVal = CUDA. zeros (Cint,nnz (csc))
357
357
nzVal = CUDA. zeros ($ elty,nnz (csc))
358
358
function bufferSize ()
@@ -379,7 +379,7 @@ for (elty, welty) in ((:Float16, :Float32),
379
379
@eval begin
380
380
function CuSparseMatrixCSC {$elty} (csr:: CuSparseMatrixCSR{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
381
381
m,n = size (csr)
382
- colPtr = CUDA. zeros (Cint, n+ 1 )
382
+ colPtr = (index == ' O ' ) ? CUDA . ones (Cint, n + 1 ) : CUDA. zeros (Cint, n+ 1 )
383
383
rowVal = CUDA. zeros (Cint, nnz (csr))
384
384
nzVal = CUDA. zeros ($ elty, nnz (csr))
385
385
if $ elty == Float16 # broken for ComplexF16?
@@ -405,7 +405,7 @@ for (elty, welty) in ((:Float16, :Float32),
405
405
406
406
function CuSparseMatrixCSR {$elty} (csc:: CuSparseMatrixCSC{$elty} ; index:: SparseChar = ' O' , action:: cusparseAction_t = CUSPARSE_ACTION_NUMERIC, algo:: cusparseCsr2CscAlg_t = CUSPARSE_CSR2CSC_ALG1)
407
407
m,n = size (csc)
408
- rowPtr = CUDA. zeros (Cint,m+ 1 )
408
+ rowPtr = (index == ' O ' ) ? CUDA. ones (Cint, m + 1 ) : CUDA . zeros (Cint, m+ 1 )
409
409
colVal = CUDA. zeros (Cint,nnz (csc))
410
410
nzVal = CUDA. zeros ($ elty,nnz (csc))
411
411
if $ elty == Float16 # broken for ComplexF16?
0 commit comments