Skip to content

Commit 7a83380

Browse files
authored
CUSPARSE SpGEMM: Support algorithms 2 and 3 (#2769)
1 parent a5b5bd7 commit 7a83380

File tree

4 files changed

+183
-89
lines changed

4 files changed

+183
-89
lines changed

lib/cusparse/generic.jl

Lines changed: 50 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -427,6 +427,9 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
427427
throw(ArgumentError("Sparse matrix-matrix multiplication only supports transa ($transa) = 'N' and transb ($transb) = 'N'"))
428428
end
429429

430+
alpha_ref = Ref{T}(alpha)
431+
beta_ref = Ref{T}(beta)
432+
430433
descA = CuSparseMatrixDescriptor(A, index)
431434
descB = CuSparseMatrixDescriptor(B, index)
432435
descC = CuSparseMatrixDescriptor(C, index)
@@ -440,30 +443,61 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
440443
function buffer1Size()
441444
out = Ref{Csize_t}(0)
442445
cusparseSpGEMM_workEstimation(
443-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
446+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
444447
descC, T, algo, spgemm_desc, out, CU_NULL)
445448
return out[]
446449
end
447450
with_workspace(buffer1, buffer1Size) do buffer
448451
out = Ref{Csize_t}(sizeof(buffer))
449452
cusparseSpGEMM_workEstimation(
450-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
453+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
451454
descC, T, algo, spgemm_desc, out, buffer)
452455
end
453456

454457
# compute the structure of the output matrix and its values in a temporary buffer
455-
function buffer2Size()
456-
out = Ref{Csize_t}(0)
457-
cusparseSpGEMM_compute(
458-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
459-
descC, T, algo, spgemm_desc, out, CU_NULL)
460-
return out[]
461-
end
462-
with_workspace(buffer2, buffer2Size) do buffer
463-
out = Ref{Csize_t}(sizeof(buffer))
464-
cusparseSpGEMM_compute(
465-
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
466-
descC, T, algo, spgemm_desc, out, buffer)
458+
if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
459+
function buffer2Size()
460+
out = Ref{Csize_t}(0)
461+
cusparseSpGEMM_compute(
462+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
463+
descC, T, algo, spgemm_desc, out, CU_NULL)
464+
return out[]
465+
end
466+
with_workspace(buffer2, buffer2Size) do buffer
467+
out = Ref{Csize_t}(sizeof(buffer))
468+
cusparseSpGEMM_compute(
469+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
470+
descC, T, algo, spgemm_desc, out, buffer)
471+
end
472+
elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
473+
chunk_fraction = Cfloat(0.2) # as per NVIDIA example (make it configurable?)
474+
function buffer3Size()
475+
out = Ref{Csize_t}(0)
476+
cusparseSpGEMM_estimateMemory(
477+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
478+
descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0)
479+
return out[]
480+
end
481+
with_workspace(buffer3Size) do buffer3
482+
function buffer2Size()
483+
out = Ref{Csize_t}(0)
484+
cusparseSpGEMM_estimateMemory(
485+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
486+
descC, T, algo, spgemm_desc, chunk_fraction, sizeof(buffer3),
487+
buffer3, out)
488+
return out[]
489+
end
490+
with_workspace(buffer2, buffer2Size) do buffer
491+
unsafe_free!(buffer3)
492+
493+
out = Ref{Csize_t}(sizeof(buffer))
494+
cusparseSpGEMM_compute(
495+
handle(), transa, transb, alpha_ref, descA, descB, beta_ref,
496+
descC, T, algo, spgemm_desc, out, buffer)
497+
end
498+
end
499+
else
500+
throw(ArgumentError("Unsupported SpGEMM algorithm: $algo"))
467501
end
468502
CUDA.unsafe_free!(buffer1)
469503

@@ -491,8 +525,8 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
491525
end
492526

493527
# copy the offsets, column indices, and values to the output matrix
494-
cusparseSpGEMM_copy(handle(), transa, transb, Ref{T}(alpha), descA, descB,
495-
Ref{T}(beta), descC, T, algo, spgemm_desc)
528+
cusparseSpGEMM_copy(handle(), transa, transb, alpha_ref, descA, descB,
529+
beta_ref, descC, T, algo, spgemm_desc)
496530
CUDA.unsafe_free!(buffer2)
497531
end
498532

0 commit comments

Comments
 (0)