Skip to content

Commit 19fc98b

Browse files
tam724maleadt
authored andcommitted
fix conversion of 0x0 CuSparseMatrixCSC <-> CSR
1 parent 205c238 commit 19fc98b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

lib/cusparse/conversions.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
332332
@eval begin
333333
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
334334
m,n = size(csr)
335-
colPtr = CUDA.zeros(Cint, n+1)
335+
colPtr = (index == 'O') ? CUDA.ones(Cint, n+1) : CUDA.zeros(Cint, n+1)
336336
rowVal = CUDA.zeros(Cint, nnz(csr))
337337
nzVal = CUDA.zeros($elty, nnz(csr))
338338
function bufferSize()
@@ -352,7 +352,7 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
352352

353353
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
354354
m,n = size(csc)
355-
rowPtr = CUDA.zeros(Cint,m+1)
355+
rowPtr = (index == 'O') ? CUDA.ones(Cint, m+1) : CUDA.zeros(Cint, m+1)
356356
colVal = CUDA.zeros(Cint,nnz(csc))
357357
nzVal = CUDA.zeros($elty,nnz(csc))
358358
function bufferSize()
@@ -379,7 +379,7 @@ for (elty, welty) in ((:Float16, :Float32),
379379
@eval begin
380380
function CuSparseMatrixCSC{$elty}(csr::CuSparseMatrixCSR{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
381381
m,n = size(csr)
382-
colPtr = CUDA.zeros(Cint, n+1)
382+
colPtr = (index == 'O') ? CUDA.ones(Cint, n+1) : CUDA.zeros(Cint, n+1)
383383
rowVal = CUDA.zeros(Cint, nnz(csr))
384384
nzVal = CUDA.zeros($elty, nnz(csr))
385385
if $elty == Float16 #broken for ComplexF16?
@@ -405,7 +405,7 @@ for (elty, welty) in ((:Float16, :Float32),
405405

406406
function CuSparseMatrixCSR{$elty}(csc::CuSparseMatrixCSC{$elty}; index::SparseChar='O', action::cusparseAction_t=CUSPARSE_ACTION_NUMERIC, algo::cusparseCsr2CscAlg_t=CUSPARSE_CSR2CSC_ALG1)
407407
m,n = size(csc)
408-
rowPtr = CUDA.zeros(Cint,m+1)
408+
rowPtr = (index == 'O') ? CUDA.ones(Cint, m+1) : CUDA.zeros(Cint, m+1)
409409
colVal = CUDA.zeros(Cint,nnz(csc))
410410
nzVal = CUDA.zeros($elty,nnz(csc))
411411
if $elty == Float16 #broken for ComplexF16?

test/libraries/cusparse/conversions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,19 @@ if !(v"12.0" <= CUSPARSE.version() < v"12.1")
113113
end
114114
end
115115

116-
for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5)]
116+
for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5), (0, 1, 0.0)]
117117
v"12.0" <= CUSPARSE.version() < v"12.1" && n == 4 && continue
118118
@testset "conversions between CuSparseMatrices (n, bd, p) = ($n, $bd, $p)" begin
119119
A = sprand(n, n, p)
120120
blockdim = bd
121121
for CuSparseMatrixType1 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
122+
if CuSparseMatrixType1 == CuSparseMatrixBSR && n == 0 continue end # TODO: conversion to CuSparseMatrixBSR breaks with (0x0) matrices
122123
dA1 = CuSparseMatrixType1 == CuSparseMatrixBSR ? CuSparseMatrixType1(A, blockdim) : CuSparseMatrixType1(A)
123124
@testset "conversion $CuSparseMatrixType1 --> SparseMatrixCSC" begin
124125
@test SparseMatrixCSC(dA1) A
125126
end
126127
for CuSparseMatrixType2 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
128+
if CuSparseMatrixType2 == CuSparseMatrixBSR && n == 0 continue end # TODO: conversion to CuSparseMatrixBSR breaks with (0x0) matrices
127129
CuSparseMatrixType1 == CuSparseMatrixType2 && continue
128130
dA2 = CuSparseMatrixType2 == CuSparseMatrixBSR ? CuSparseMatrixType2(dA1, blockdim) : CuSparseMatrixType2(dA1)
129131
@testset "conversion $CuSparseMatrixType1 --> $CuSparseMatrixType2" begin

0 commit comments

Comments
 (0)