Skip to content

Commit 71bc923

Browse files
CUSPARSE: Fix out-of-place SpGemm + tests (#2773)
Co-authored-by: Tim Besard <[email protected]>
1 parent 82c2074 commit 71bc923

File tree

2 files changed

+59
-25
lines changed

2 files changed

+59
-25
lines changed

lib/cusparse/generic.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -549,51 +549,85 @@ end
549549
function gemm(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
550550
B::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSpGEMMAlg_t=CUSPARSE_SPGEMM_DEFAULT) where {T}
551551

552-
m,k = size(A)
553-
l,n = size(B)
552+
m, k = size(A)
553+
l, n = size(B)
554554

555-
(k != l) && throw(DimensionMismatch("A must have the same the number of columns that B has as rows, but A has $k columns and B has $l columns"))
555+
(k != l) && throw(DimensionMismatch("A must have the same number of columns that B has as rows, but A has $k columns and B has $l rows."))
556556
!(transa == 'N' && transb == 'N') && throw(ArgumentError("Sparse matrix-matrix multiplication only supports transa ($transa) = 'N' and transb ($transb) = 'N'"))
557557

558+
alpha_ref = Ref{T}(convert(T, alpha))
559+
beta_ref = Ref{T}(zero(T))
560+
558561
descA = CuSparseMatrixDescriptor(A, index)
559562
descB = CuSparseMatrixDescriptor(B, index)
560563

561-
rowPtr = CuVector{Cint}(undef, m+1)
564+
rowPtr = CuVector{Cint}(undef, m + 1)
562565
descC = CuSparseMatrixDescriptor(CuSparseMatrixCSR, rowPtr, T, Cint, m, n, index)
563566

564567
spgemm_desc = CuSpGEMMDescriptor()
565568

566569
buffer1 = CuVector{UInt8}(undef, 0)
567570
buffer2 = CuVector{UInt8}(undef, 0)
568-
GC.@preserve buffer1 buffer1 begin
571+
GC.@preserve buffer1 buffer2 rowPtr begin
569572
# compute an upper bound of the memory required for the intermediate products.
570573
function buffer1Size()
571574
out = Ref{Csize_t}(0)
572575
cusparseSpGEMM_workEstimation(
573-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(0),
576+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
574577
descC, T, algo, spgemm_desc, out, CU_NULL)
575578
return out[]
576579
end
577580
with_workspace(buffer1, buffer1Size) do buffer
578581
out = Ref{Csize_t}(sizeof(buffer))
579582
cusparseSpGEMM_workEstimation(
580-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(0),
583+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
581584
descC, T, algo, spgemm_desc, out, buffer)
582585
end
583586

584587
# compute the structure of the output matrix and its values in a temporary buffer
585-
function buffer2Size()
586-
out = Ref{Csize_t}(0)
587-
cusparseSpGEMM_compute(
588-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(0),
589-
descC, T, algo, spgemm_desc, out, CU_NULL)
590-
return out[]
591-
end
592-
with_workspace(buffer2, buffer2Size) do buffer
593-
out = Ref{Csize_t}(sizeof(buffer))
594-
cusparseSpGEMM_compute(
595-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(0),
596-
descC, T, algo, spgemm_desc, out, buffer)
588+
if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
589+
function buffer2Size()
590+
out = Ref{Csize_t}(0)
591+
cusparseSpGEMM_compute(
592+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
593+
descC, T, algo, spgemm_desc, out, CU_NULL)
594+
return out[]
595+
end
596+
with_workspace(buffer2, buffer2Size) do buffer
597+
out = Ref{Csize_t}(sizeof(buffer))
598+
cusparseSpGEMM_compute(
599+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
600+
descC, T, algo, spgemm_desc, out, buffer)
601+
end
602+
elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
603+
chunk_fraction = Cfloat(0.2) # as per NVIDIA example (make it configurable?)
604+
function buffer3Size()
605+
out = Ref{Csize_t}(0)
606+
cusparseSpGEMM_estimateMemory(
607+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
608+
descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0)
609+
return out[]
610+
end
611+
with_workspace(buffer3Size) do buffer3
612+
function buffer2Size()
613+
out = Ref{Csize_t}(0)
614+
cusparseSpGEMM_estimateMemory(
615+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
616+
descC, T, algo, spgemm_desc, chunk_fraction, sizeof(buffer3),
617+
buffer3, out)
618+
return out[]
619+
end
620+
with_workspace(buffer2, buffer2Size) do buffer
621+
unsafe_free!(buffer3)
622+
623+
out = Ref{Csize_t}(sizeof(buffer))
624+
cusparseSpGEMM_compute(
625+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
626+
descC, T, algo, spgemm_desc, out, buffer)
627+
end
628+
end
629+
else
630+
throw(ArgumentError("Unsupported SpGEMM algorithm: $algo"))
597631
end
598632
CUDA.unsafe_free!(buffer1)
599633

test/libraries/cusparse/generic.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,12 @@ end
324324
SPGEMM_ALGOS = Dict(CuSparseMatrixCSR => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT],
325325
CuSparseMatrixCSC => [CUSPARSE.CUSPARSE_SPGEMM_DEFAULT])
326326
if CUSPARSE.version() >= v"12.0"
327-
push!(SPGEMM_ALGOS[CuSparseMatrixCSR], CUSPARSE.CUSPARSE_SPGEMM_ALG1)
328-
CUSPARSE.CUSPARSE_SPGEMM_ALG2
329-
CUSPARSE.CUSPARSE_SPGEMM_ALG3
330-
push!(SPGEMM_ALGOS[CuSparseMatrixCSC], CUSPARSE.CUSPARSE_SPGEMM_ALG1)
331-
CUSPARSE.CUSPARSE_SPGEMM_ALG2
332-
CUSPARSE.CUSPARSE_SPGEMM_ALG3
327+
append!(SPGEMM_ALGOS[CuSparseMatrixCSR], (CUSPARSE.CUSPARSE_SPGEMM_ALG1,
328+
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
329+
CUSPARSE.CUSPARSE_SPGEMM_ALG3))
330+
append!(SPGEMM_ALGOS[CuSparseMatrixCSC], (CUSPARSE.CUSPARSE_SPGEMM_ALG1,
331+
CUSPARSE.CUSPARSE_SPGEMM_ALG2,
332+
CUSPARSE.CUSPARSE_SPGEMM_ALG3))
333333
end
334334
# Algorithms CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_DETERMINITIC and
335335
# CUSPARSE.CUSPARSE_SPGEMM_CSR_ALG_NONDETERMINITIC are dedicated to the cusparseSpGEMMreuse routine.

0 commit comments

Comments
 (0)