Skip to content

Commit 779e4dc

Browse files
authored
Merge pull request #2651 from JuliaGPU/ksh/mixedsparse
Re-enable mixed precision sparse mv
2 parents 4067511 + f14d887 commit 779e4dc

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

lib/cusparse/interfaces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,19 @@ op_wrappers = ((identity, T -> 'N', identity),
6262
(T -> :(HermOrSym{T, <:$T}), T -> 'N', A -> :(parent($A))))
6363

6464
# legacy methods with final MulAddMul argument
65-
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
65+
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{S}, B::DenseCuVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}, S <: Union{Float16, ComplexF16, BlasFloat}} =
6666
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
67-
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
67+
LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{S}, B::CuSparseVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}, S <: Union{Float16, ComplexF16, BlasFloat}} =
6868
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
6969
LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
7070
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
7171

72-
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::DenseCuVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
72+
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{S}, B::DenseCuVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}, S <: Union{Float16, ComplexF16, BlasFloat}}
7373
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
7474
mv_wrapper(tA, alpha, A, B, beta, C)
7575
end
7676

77-
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{T}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
77+
function LinearAlgebra.generic_matvecmul!(C::CuVector{T}, tA::AbstractChar, A::CuSparseMatrix{S}, B::CuSparseVector{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}, S <: Union{Float16, ComplexF16, BlasFloat}}
7878
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
7979
mv_wrapper(tA, alpha, A, CuVector{T}(B), beta, C)
8080
end

test/libraries/cusparse/interfaces.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,25 @@ using LinearAlgebra, SparseArrays
128128
end
129129
end
130130

131+
@testset "CuSparseMatrix * CuVector -- mul!(c, A, b) mixed $eltys" for eltys in ((Float32, ComplexF32), (Float64, ComplexF64))
132+
eltya, eltyb = eltys
133+
for opa in (identity, transpose, adjoint)
134+
n = 10
135+
m = 20
136+
A = opa == identity ? sprand(eltya, n, m, 0.5) : sprand(eltya, m, n, 0.5)
137+
b = rand(eltyb, m)
138+
c = rand(eltyb, n)
139+
140+
dA = CuSparseMatrixCSR(A)
141+
db = CuArray(b)
142+
dc = CuArray(c)
143+
144+
mul!(c, opa(A), b)
145+
mul!(dc, opa(dA), db)
146+
@test c collect(dc)
147+
end
148+
end
149+
131150
for SparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
132151

133152
if CUSPARSE.version() >= v"11.7.4"
@@ -565,4 +584,4 @@ end
565584

566585
@test ref_cuda_sparse.colPtr == cuda_spdiagm.colPtr
567586
end
568-
end
587+
end

0 commit comments

Comments
 (0)