@@ -427,6 +427,9 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
427427 throw (ArgumentError (" Sparse matrix-matrix multiplication only supports transa ($transa ) = 'N' and transb ($transb ) = 'N'" ))
428428 end
429429
430+ alpha_ref = Ref {T} (alpha)
431+ beta_ref = Ref {T} (beta)
432+
430433 descA = CuSparseMatrixDescriptor (A, index)
431434 descB = CuSparseMatrixDescriptor (B, index)
432435 descC = CuSparseMatrixDescriptor (C, index)
@@ -440,30 +443,61 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
440443 function buffer1Size ()
441444 out = Ref {Csize_t} (0 )
442445 cusparseSpGEMM_workEstimation (
443- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} (beta) ,
446+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
444447 descC, T, algo, spgemm_desc, out, CU_NULL)
445448 return out[]
446449 end
447450 with_workspace (buffer1, buffer1Size) do buffer
448451 out = Ref {Csize_t} (sizeof (buffer))
449452 cusparseSpGEMM_workEstimation (
450- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} (beta) ,
453+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
451454 descC, T, algo, spgemm_desc, out, buffer)
452455 end
453456
454457 # compute the structure of the output matrix and its values in a temporary buffer
455- function buffer2Size ()
456- out = Ref {Csize_t} (0 )
457- cusparseSpGEMM_compute (
458- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
459- descC, T, algo, spgemm_desc, out, CU_NULL)
460- return out[]
461- end
462- with_workspace (buffer2, buffer2Size) do buffer
463- out = Ref {Csize_t} (sizeof (buffer))
464- cusparseSpGEMM_compute (
465- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
466- descC, T, algo, spgemm_desc, out, buffer)
458+ if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
459+ function buffer2Size ()
460+ out = Ref {Csize_t} (0 )
461+ cusparseSpGEMM_compute (
462+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
463+ descC, T, algo, spgemm_desc, out, CU_NULL)
464+ return out[]
465+ end
466+ with_workspace (buffer2, buffer2Size) do buffer
467+ out = Ref {Csize_t} (sizeof (buffer))
468+ cusparseSpGEMM_compute (
469+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
470+ descC, T, algo, spgemm_desc, out, buffer)
471+ end
472+ elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
473+ chunk_fraction = Cfloat (0.2 ) # as per NVIDIA example (make it configurable?)
474+ function buffer3Size ()
475+ out = Ref {Csize_t} (0 )
476+ cusparseSpGEMM_estimateMemory (
477+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
478+ descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0 )
479+ return out[]
480+ end
481+ with_workspace (buffer3Size) do buffer3
482+ function buffer2Size ()
483+ out = Ref {Csize_t} (0 )
484+ cusparseSpGEMM_estimateMemory (
485+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
486+ descC, T, algo, spgemm_desc, chunk_fraction, sizeof (buffer3),
487+ buffer3, out)
488+ return out[]
489+ end
490+ with_workspace (buffer2, buffer2Size) do buffer
491+ unsafe_free! (buffer3)
492+
493+ out = Ref {Csize_t} (sizeof (buffer))
494+ cusparseSpGEMM_compute (
495+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
496+ descC, T, algo, spgemm_desc, out, buffer)
497+ end
498+ end
499+ else
500+ throw (ArgumentError (" Unsupported SpGEMM algorithm: $algo " ))
467501 end
468502 CUDA. unsafe_free! (buffer1)
469503
@@ -491,8 +525,8 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
491525 end
492526
493527 # copy the offsets, column indices, and values to the output matrix
494- cusparseSpGEMM_copy (handle (), transa, transb, Ref {T} (alpha) , descA, descB,
495- Ref {T} (beta) , descC, T, algo, spgemm_desc)
528+ cusparseSpGEMM_copy (handle (), transa, transb, alpha_ref , descA, descB,
529+ beta_ref , descC, T, algo, spgemm_desc)
496530 CUDA. unsafe_free! (buffer2)
497531 end
498532
0 commit comments