Skip to content

Conversation

@tam724
Copy link
Contributor

@tam724 tam724 commented Jan 1, 2026

The current implementation of matrix-matrix multiplication mm! of a dense and a sparse matrix (with a coo-matrix as the second argument) required a user to sort the sparse matrix before multiplication. This was tested behavior, see generic.jl or interfaces.jl.
This is because CUSPARSE only provides an API for C = a*A*B + b*C (where A is sparse), the case where B is sparse is implemented by transposing the identity to Ct = a*Bt*At + b*Ct. For csc/csr we can exchange the type to realize the transpose, but for coo we would have to resort the col- and row-indices. Sorting currently copies the whole matrix.

The requirement to sort before multiplication is somewhat unexpected, especially in the higher level interfaces mul! and * which should produce the correct result without prior sorting (or should check for correct ordering and warn/error). Currently the matrix multiplication returns normally but with a wrong result (see #2820).

However, in almost all cases we can realize the multiplication of dense * coo without sorting by flipping the 'N'/'T'/'C' argument of the sparse matrix correspondingly. See this PR.
The only case that cannot be easily realized is if the eltype(B) of the matrix is <:Complex and the transb argument == 'C'. The lazy transpose would need a "only conjugate" (no transpose) flag. This PR implements this case by materializing the conjugate of the coo-matrix entries. (Compared to sorting this already reduces memory consumption, because only a copy of the entries is required instead of a full copy).

An alternative implementation (fully inplace) could: 1. conjugate the values of C (if !iszero(b)), 2. compute the matrix product and 3. conjugate the values of C again. Without a only conjugate flag provided by CUSPARSE, I don't see a way to realize this case without additional work or implementing our own matmul (which would probably be slow).

Both of these options are suboptimal for a matmul implementation, we could also error for this specific case asking the user to supply the matrix in csc/csr format. Any opinions on that?

Closes #2820.

Also this removes mm_wrapper that would only duplicate the size checks in mm! and shortcut if isempty(B) but return a zero matrix of wrong shape.

@github-actions
Copy link
Contributor

github-actions bot commented Jan 1, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/cusparse/generic.jl b/lib/cusparse/generic.jl
index 12c599b82..23dd30f3f 100644
--- a/lib/cusparse/generic.jl
+++ b/lib/cusparse/generic.jl
@@ -273,7 +273,8 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},C
 end
 
 function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix{T},
-             B::Union{DenseCuMatrix{T}, Transpose{T, <:DenseCuMatrix{T}}}, beta::Number, C::Union{DenseCuMatrix{T}, Transpose{T, <:DenseCuMatrix{T}}}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}
+        B::Union{DenseCuMatrix{T}, Transpose{T, <:DenseCuMatrix{T}}}, beta::Number, C::Union{DenseCuMatrix{T}, Transpose{T, <:DenseCuMatrix{T}}}, index::SparseChar, algo::cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT
+    ) where {T}
 
     (A isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.5.1") && throw(ErrorException("This operation is not supported by the current CUDA version."))
 
@@ -321,19 +322,19 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
     return C
 end
 
-function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where T
+function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCSC{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
     _B = CuSparseMatrixCSR{T}(B.colPtr, B.rowVal, B.nzVal, reverse(B.dims))
     mm!(transb, transa, alpha, _B, transpose(A), beta, transpose(C), index, algo)
     return C
 end
 
-function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where T
+function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCSR{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
     _B = CuSparseMatrixCSC{T}(B.rowPtr, B.colVal, B.nzVal, reverse(B.dims))
     mm!(transb, transa, alpha, _B, transpose(A), beta, transpose(C), index, algo)
     return C
 end
 
-function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where T
+function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::CuSparseMatrixCOO{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
     if T <: Real || transb == 'N' || transb == 'T'
         mm!((transb == 'N') ? 'T' : 'N', transa, alpha, B, transpose(A), beta, transpose(C), index, algo)
     else # transb == 'C'
diff --git a/lib/cusparse/helpers.jl b/lib/cusparse/helpers.jl
index 3980c5097..a24a62c8f 100644
--- a/lib/cusparse/helpers.jl
+++ b/lib/cusparse/helpers.jl
@@ -103,7 +103,7 @@ mutable struct CuDenseMatrixDescriptor
         obj
     end
 
-    CuDenseMatrixDescriptor(At::Transpose{<:Any, <:DenseCuMatrix}) = CuDenseMatrixDescriptor(parent(At), transposed=true) 
+    CuDenseMatrixDescriptor(At::Transpose{<:Any, <:DenseCuMatrix}) = CuDenseMatrixDescriptor(parent(At), transposed = true)
 
     function CuDenseMatrixDescriptor(A::DenseCuArray{T, 3}; transposed::Bool=false) where T
         desc_ref = Ref{cusparseDnMatDescr_t}()
diff --git a/lib/cusparse/interfaces.jl b/lib/cusparse/interfaces.jl
index 84bd0cea6..d5b753949 100644
--- a/lib/cusparse/interfaces.jl
+++ b/lib/cusparse/interfaces.jl
@@ -70,7 +70,7 @@ end
 function LinearAlgebra.generic_matmatmul!(C::CuMatrix{T}, tA, tB, A::CuSparseMatrix{T}, B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
     tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
     tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
-    mm!(tA, tB, alpha, A, B, beta, C, 'O')
+    return mm!(tA, tB, alpha, A, B, beta, C, 'O')
 end
 
 for (wrapa, transa, unwrapa) in op_wrappers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Dense*Sparse matrix multiplication not working correctly for CuSparseMatrixCOO

1 participant