Skip to content

Commit 0be4f7b

Browse files
authored
CUSPARSE: Bugfixes for sparse vector broadcast. (#2780)
1 parent 2b3d0e6 commit 0be4f7b

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

lib/cusparse/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,12 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
383383
row = @inbounds row_and_ptrs[1]
384384
arg_ptrs = @inbounds row_and_ptrs[2]
385385
vals = ntuple(Val(N)) do i
386+
@inline
386387
arg = @inbounds args[i]
387388
# ptr is 0 if the sparse vector doesn't have an element at this row
388389
# ptr is 0 if the arg is a scalar AND f preserves zeros
389390
ptr = @inbounds arg_ptrs[i]
390-
_getindex(arg, row, ptr)::Tv
391+
_getindex(arg, row, ptr)
391392
end
392393
output_val = f(vals...)
393394
@inbounds output.iPtr[row_ix] = row
@@ -470,12 +471,13 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
470471
row = @inbounds row_and_ptrs[1]
471472
arg_ptrs = @inbounds row_and_ptrs[2]
472473
vals = ntuple(Val(length(args))) do i
474+
@inline
473475
arg = @inbounds args[i]
474476
# ptr is 0 if the sparse vector doesn't have an element at this row
475477
# ptr is row if the arg is dense OR a scalar with non-zero-preserving f
476478
# ptr is 0 if the arg is a scalar AND f preserves zeros
477479
ptr = @inbounds arg_ptrs[i]
478-
_getindex(arg, row, ptr)::Tv
480+
_getindex(arg, row, ptr)
479481
end
480482
@inbounds output[row] = f(vals...)
481483
return

test/libraries/cusparse/broadcast.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using CUDA.CUSPARSE, SparseArrays
3232
dz = dx .* dy .* elty(2)
3333
@test dz isa typ{elty}
3434
@test z == SparseMatrixCSC(dz)
35-
35+
3636
# multiple inputs
3737
w = sprand(elty, m, n, p)
3838
dw = typ(w)
@@ -42,34 +42,34 @@ using CUDA.CUSPARSE, SparseArrays
4242
@test z == SparseMatrixCSC(dz)
4343
end
4444
@testset "$typ($elty)" for typ in [CuSparseVector,]
45-
m = 64
45+
m = 64
4646
p = 0.5
4747
x = sprand(elty, m, p)
4848
dx = typ(x)
49-
49+
5050
# zero-preserving
5151
y = x .* elty(1)
5252
dy = dx .* elty(1)
5353
@test dy isa typ{elty}
54-
@test collect(dy.iPtr) == collect(dx.iPtr)
54+
@test collect(dy.iPtr) == collect(dx.iPtr)
5555
@test collect(dy.iPtr) == y.nzind
5656
@test collect(dy.nzVal) == y.nzval
5757
@test y == SparseVector(dy)
58-
58+
5959
# not zero-preserving
6060
y = x .+ elty(1)
6161
dy = dx .+ elty(1)
6262
@test dy isa CuArray{elty}
6363
hy = Array(dy)
64-
@test Array(y) == hy
64+
@test Array(y) == hy
6565

6666
# involving something dense
6767
y = x .+ ones(elty, m)
6868
dy = dx .+ CUDA.ones(elty, m)
6969
@test dy isa CuArray{elty}
7070
@test y == Array(dy)
71-
72-
# sparse to sparse
71+
72+
# sparse to sparse
7373
dx = typ(x)
7474
y = sprand(elty, m, p)
7575
dy = typ(y)
@@ -88,25 +88,41 @@ using CUDA.CUSPARSE, SparseArrays
8888
dz = @. dx * dy * dw
8989
@test dz isa typ{elty}
9090
@test z == SparseVector(dz)
91-
91+
9292
y = sprand(elty, m, p)
9393
w = sprand(elty, m, p)
9494
dense_arr = rand(elty, m)
95-
d_dense_arr = CuArray(dense_arr)
95+
d_dense_arr = CuArray(dense_arr)
9696
dy = typ(y)
9797
dw = typ(w)
98-
z = @. x * y * w * dense_arr
99-
dz = @. dx * dy * dw * d_dense_arr
98+
z = @. x * y * w * dense_arr
99+
dz = @. dx * dy * dw * d_dense_arr
100100
@test dz isa CuArray{elty}
101101
@test z == Array(dz)
102-
102+
103103
y = sprand(elty, m, p)
104104
dy = typ(y)
105105
dx = typ(x)
106106
z = x .* y .* elty(2)
107107
dz = dx .* dy .* elty(2)
108108
@test dz isa typ{elty}
109109
@test z == SparseVector(dz)
110+
111+
# type-mismatching
112+
## non-zero-preserving
113+
dx = typ(x)
114+
dy = dx .+ 1
115+
y = x .+ 1
116+
@test dy isa CuArray{promote_type(elty, Int)}
117+
@test y == Array(dy)
118+
## zero-preserving
119+
dy = dx .* 1
120+
y = x .* 1
121+
@test dy isa typ{promote_type(elty, Int)}
122+
@test collect(dy.iPtr) == collect(dx.iPtr)
123+
@test collect(dy.iPtr) == y.nzind
124+
@test collect(dy.nzVal) == y.nzval
125+
@test y == SparseVector(dy)
110126
end
111127
end
112128

0 commit comments

Comments
 (0)