You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
50
-
end
51
-
end
52
-
53
-
for (tagb, untagb) in tag_wrappers, (wrapb, transb, unwrapb) in op_wrappers
54
-
TypeB =wrapb(tagb(:(DenseROCMatrix{T})))
55
-
56
-
@evalbegin
57
-
function LinearAlgebra.mul!(
58
-
C::ROCMatrix{T}, A::$TypeA, B::$TypeB,
59
-
alpha::Number, beta::Number,
60
-
) where T <:Union{Float16, ComplexF16, BlasFloat}
61
-
mm_wrapper(
62
-
$transa(T), $transb(T), alpha,
63
-
$(untaga(unwrapa(:A))), $(untagb(unwrapb(:B))), beta, C)
64
-
end
65
-
end
66
-
end
25
+
# legacy methods with final MulAddMul argument
26
+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <:BlasFloat=
27
+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
28
+
LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, _add::MulAddMul) where T <:BlasFloat=
29
+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
30
+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, _add::MulAddMul) where T <:BlasFloat=
31
+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
32
+
33
+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, alpha::Number, beta::Number) where T <:BlasFloat
34
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
35
+
mv_wrapper(tA, alpha, A, B, beta, C)
36
+
end
37
+
38
+
function LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::ROCSparseVector{T}, alpha::Number, beta::Number) where T <:BlasFloat
39
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
40
+
mv_wrapper(tA, alpha, A, ROCVector{T}(B), beta, C)
41
+
end
42
+
43
+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseMatrix{T}, B::DenseROCMatrix{T}, alpha::Number, beta::Number) where T <:BlasFloat
44
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
45
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
46
+
mm_wrapper(tA, tB, alpha, A, B, beta, C)
47
+
end
48
+
49
+
# legacy methods with final MulAddMul argument
50
+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <:BlasFloat=
51
+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
52
+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where T <:BlasFloat=
53
+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
54
+
LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where T <:BlasFloat=
55
+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
56
+
57
+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where T <:BlasFloat
58
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
59
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
60
+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
61
+
end
62
+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where T <:BlasFloat
63
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
64
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
65
+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
66
+
end
67
+
function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where T <:BlasFloat
68
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
69
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
70
+
mm!(tA, tB, alpha, A, B, beta, C, 'O')
67
71
end
68
72
69
73
Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) =geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
0 commit comments