Skip to content

Commit e656414

Browse files
authored
More tests and a matmatmul fix (#2697)
1 parent 4df814c commit e656414

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

lib/cusparse/interfaces.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,9 @@ function LinearAlgebra.generic_matmatmul!(C::CuSparseMatrixCOO{T}, tA, tB, A::Cu
177177
A_csr = CuSparseMatrixCSR(A)
178178
B_csr = CuSparseMatrixCSR(B)
179179
C_csr = CuSparseMatrixCSR(C)
180-
generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, alpha, beta)
181-
C = CuSparseMatrixCOO(C_csr) # is this in-place of the original C?
180+
LinearAlgebra.generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, alpha, beta)
181+
copyto!(C, CuSparseMatrixCOO(C_csr))
182+
return C
182183
end
183184

184185
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR)

test/libraries/cusparse.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,14 @@ blockdim = 5
101101
d_x = CuSparseMatrixCOO(x)
102102
d_tx = CuSparseMatrixCOO(transpose(x))
103103
d_ax = CuSparseMatrixCOO(adjoint(x))
104+
d_tcx = CuSparseMatrixCOO(transpose(CuSparseMatrixCSC(x)))
105+
d_acx = CuSparseMatrixCOO(adjoint(CuSparseMatrixCSC(x)))
106+
# reordered I, J to test other indexing path
107+
d_rx = CuSparseMatrixCOO{eltype(d_x), Int32}(copy(d_x.colInd), copy(d_x.rowInd), copy(d_x.nzVal))
104108
@test CuSparseMatrixCOO(d_x) === d_x
105109
@test length(d_x) == m*n
106110
@test size(d_x) == (m,n)
111+
@test size(d_rx) == (n,m)
107112
@test size(d_x,1) == m
108113
@test size(d_x,2) == n
109114
@test size(d_x,3) == 1
@@ -119,6 +124,10 @@ blockdim = 5
119124
@test d_x[end] == x[end]
120125
@test d_tx[:, 1] == transpose(x)[:, 1]
121126
@test d_ax[1, :] == adjoint(x)[1, :]
127+
@test d_tcx[:, 1] == transpose(x)[:, 1]
128+
@test d_acx[1, :] == adjoint(x)[1, :]
129+
@test d_rx[:, 1] == transpose(x)[:, 1]
130+
@test d_rx[1, :] == transpose(x)[1, :]
122131
@test d_x[firstindex(d_x), firstindex(d_x)] == x[firstindex(x), firstindex(x)]
123132
@test d_x[div(end, 2), div(end, 2)] == x[div(end, 2), div(end, 2)]
124133
@test d_x[end, end] == x[end, end]

test/libraries/cusparse/interfaces.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ using LinearAlgebra, SparseArrays
8585
C = opa(A) * opb(B)
8686
dC = opa(dA) * opb(dB)
8787
@test C collect(dC)
88+
if opa == opb == identity
89+
dA = SparseMatrixType(A)
90+
dB = SparseMatrixType(B)
91+
mul!(dC, opa(dA), opb(dB), 3, 3.2)
92+
C = 3.2 * C + 3 * opa(A) * opb(B)
93+
@show SparseMatrixType
94+
@test collect(dC) C
95+
end
8896
end
8997
end
9098
end

0 commit comments

Comments
 (0)