Skip to content

Commit 78c1036

Browse files
authored
Update ROCSparse for Julia v1.10 (#613)
* [rocSPARSE] Update the interface for sparse products * Fix test for rocsparse/interfaces.jl * Fix again the tests for rocsparse/interfaces.jl
1 parent e61e088 commit 78c1036

File tree

2 files changed

+63
-83
lines changed

2 files changed

+63
-83
lines changed

src/sparse/interfaces.jl

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,48 +22,52 @@ function mm_wrapper(
2222
mm!(transa, transb, alpha, A, B, beta, C, 'O')
2323
end
2424

25-
tag_wrappers = (
26-
(identity, identity),
27-
(T -> :(HermOrSym{T, <:$T}), A -> :(parent($A))))
28-
29-
op_wrappers = (
30-
(identity, T -> 'N', identity),
31-
(T -> :(Transpose{<:T, <:$T}), T -> 'T', A -> :(parent($A))),
32-
(T -> :(Adjoint{<:T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
33-
34-
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
35-
TypeA = wrapa(taga(:(ROCSparseMatrix{T})))
36-
37-
@eval begin
38-
function LinearAlgebra.mul!(
39-
C::ROCVector{T}, A::$TypeA, B::DenseROCVector{T},
40-
alpha::Number, beta::Number,
41-
) where T <: Union{Float16, ComplexF16, BlasFloat}
42-
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
43-
end
44-
45-
function LinearAlgebra.mul!(
46-
C::ROCVector{Complex{T}}, A::$TypeA, B::DenseROCVector{Complex{T}},
47-
alpha::Number, beta::Number,
48-
) where T <: Union{Float16, BlasFloat}
49-
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
50-
end
51-
end
52-
53-
for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
54-
TypeB = wrapb(tagb(:(DenseROCMatrix{T})))
55-
56-
@eval begin
57-
function LinearAlgebra.mul!(
58-
C::ROCMatrix{T}, A::$TypeA, B::$TypeB,
59-
alpha::Number, beta::Number,
60-
) where T <: Union{Float16, ComplexF16, BlasFloat}
61-
mm_wrapper(
62-
$transa(T), $transb(T), alpha,
63-
$(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
64-
end
65-
end
66-
end
25+
# legacy methods with final MulAddMul argument
26+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
27+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
28+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, _add::MulAddMul) where T <: BlasFloat =
29+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
30+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, _add::MulAddMul) where T <: BlasFloat =
31+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
32+
33+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
34+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
35+
mv_wrapper(tA, alpha, A, B, beta, C)
36+
end
37+
38+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, alpha::Number, beta::Number) where T <: BlasFloat
39+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
40+
mv_wrapper(tA, alpha, A, ROCVector{T}(B), beta, C)
41+
end
42+
43+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, alpha::Number, beta::Number) where T <: BlasFloat
44+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
45+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
46+
mm_wrapper(tA, tB, alpha, A, B, beta, C)
47+
end
48+
49+
# legacy methods with final MulAddMul argument
50+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <: BlasFloat =
51+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
52+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where T <: BlasFloat =
53+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
54+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where T <: BlasFloat =
55+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
56+
57+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where T <: BlasFloat
58+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
59+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
60+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
61+
end
62+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where T <: BlasFloat
63+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
64+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
65+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
66+
end
67+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where T <: BlasFloat
68+
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
69+
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
70+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
6771
end
6872

6973
Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')

test/rocsparse/interfaces.jl

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -79,47 +79,23 @@
7979
LinearAlgebra.mul!(dc, f(dA), db, alpha, beta)
8080
@test c collect(dc)
8181

82-
A = A + transpose(A)
83-
dA = ROCSparseMatrixCSR(A)
84-
85-
@assert issymmetric(A)
86-
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
87-
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
88-
@test c collect(dc)
89-
end
90-
91-
@testset "$f(A)*b Complex{$elty}*$elty" for elty in (
92-
Float32, Float64,
93-
), f in (
94-
identity, transpose, adjoint,
95-
)
96-
n = 10
97-
alpha = rand()
98-
beta = rand()
99-
A = sprand(Complex{elty}, n, n, rand())
100-
b = rand(Complex{elty}, n)
101-
c = rand(Complex{elty}, n)
102-
alpha = beta = 1.0
103-
c = zeros(Complex{elty}, n)
104-
105-
dA = ROCSparseMatrixCSR(A)
106-
db = ROCArray(b)
107-
dc = ROCArray(c)
108-
109-
# test with empty inputs
110-
@test Array(dA * AMDGPU.zeros(Complex{elty}, n, 0)) == zeros(Complex{elty}, n, 0)
111-
112-
LinearAlgebra.mul!(c, f(A), b, alpha, beta)
113-
LinearAlgebra.mul!(dc, f(dA), db, alpha, beta)
114-
@test c collect(dc)
115-
116-
A = A + transpose(A)
117-
dA = ROCSparseMatrixCSR(A)
118-
119-
@assert issymmetric(A)
120-
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
121-
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
122-
@test c collect(dc)
82+
if f in (identity, transpose)
83+
A = A + transpose(A)
84+
dA = ROCSparseMatrixCSR(A)
85+
86+
@assert issymmetric(A)
87+
LinearAlgebra.mul!(c, f(Symmetric(A)), b, alpha, beta)
88+
LinearAlgebra.mul!(dc, f(Symmetric(dA)), db, alpha, beta)
89+
@test c collect(dc)
90+
else
91+
A = A + adjoint(A)
92+
dA = ROCSparseMatrixCSR(A)
93+
94+
@assert ishermitian(A)
95+
LinearAlgebra.mul!(c, f(Hermitian(A)), b, alpha, beta)
96+
LinearAlgebra.mul!(dc, f(Hermitian(dA)), db, alpha, beta)
97+
@test c collect(dc)
98+
end
12399
end
124100

125101
@testset "$f(A)*$h(B) $elty" for elty in (

0 commit comments

Comments
 (0)