@@ -427,6 +427,9 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
427
427
throw (ArgumentError (" Sparse matrix-matrix multiplication only supports transa ($transa ) = 'N' and transb ($transb ) = 'N'" ))
428
428
end
429
429
430
+ alpha_ref = Ref {T} (alpha)
431
+ beta_ref = Ref {T} (beta)
432
+
430
433
descA = CuSparseMatrixDescriptor (A, index)
431
434
descB = CuSparseMatrixDescriptor (B, index)
432
435
descC = CuSparseMatrixDescriptor (C, index)
@@ -440,30 +443,61 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
440
443
function buffer1Size ()
441
444
out = Ref {Csize_t} (0 )
442
445
cusparseSpGEMM_workEstimation (
443
- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} (beta) ,
446
+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
444
447
descC, T, algo, spgemm_desc, out, CU_NULL)
445
448
return out[]
446
449
end
447
450
with_workspace (buffer1, buffer1Size) do buffer
448
451
out = Ref {Csize_t} (sizeof (buffer))
449
452
cusparseSpGEMM_workEstimation (
450
- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} (beta) ,
453
+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
451
454
descC, T, algo, spgemm_desc, out, buffer)
452
455
end
453
456
454
457
# compute the structure of the output matrix and its values in a temporary buffer
455
- function buffer2Size ()
456
- out = Ref {Csize_t} (0 )
457
- cusparseSpGEMM_compute (
458
- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
459
- descC, T, algo, spgemm_desc, out, CU_NULL)
460
- return out[]
461
- end
462
- with_workspace (buffer2, buffer2Size) do buffer
463
- out = Ref {Csize_t} (sizeof (buffer))
464
- cusparseSpGEMM_compute (
465
- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
466
- descC, T, algo, spgemm_desc, out, buffer)
458
+ if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
459
+ function buffer2Size ()
460
+ out = Ref {Csize_t} (0 )
461
+ cusparseSpGEMM_compute (
462
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
463
+ descC, T, algo, spgemm_desc, out, CU_NULL)
464
+ return out[]
465
+ end
466
+ with_workspace (buffer2, buffer2Size) do buffer
467
+ out = Ref {Csize_t} (sizeof (buffer))
468
+ cusparseSpGEMM_compute (
469
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
470
+ descC, T, algo, spgemm_desc, out, buffer)
471
+ end
472
+ elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
473
+ chunk_fraction = Cfloat (0.2 ) # as per NVIDIA example (make it configurable?)
474
+ function buffer3Size ()
475
+ out = Ref {Csize_t} (0 )
476
+ cusparseSpGEMM_estimateMemory (
477
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
478
+ descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0 )
479
+ return out[]
480
+ end
481
+ with_workspace (buffer3Size) do buffer3
482
+ function buffer2Size ()
483
+ out = Ref {Csize_t} (0 )
484
+ cusparseSpGEMM_estimateMemory (
485
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
486
+ descC, T, algo, spgemm_desc, chunk_fraction, sizeof (buffer3),
487
+ buffer3, out)
488
+ return out[]
489
+ end
490
+ with_workspace (buffer2, buffer2Size) do buffer
491
+ unsafe_free! (buffer3)
492
+
493
+ out = Ref {Csize_t} (sizeof (buffer))
494
+ cusparseSpGEMM_compute (
495
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
496
+ descC, T, algo, spgemm_desc, out, buffer)
497
+ end
498
+ end
499
+ else
500
+ throw (ArgumentError (" Unsupported SpGEMM algorithm: $algo " ))
467
501
end
468
502
CUDA. unsafe_free! (buffer1)
469
503
@@ -491,8 +525,8 @@ function gemm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSpars
491
525
end
492
526
493
527
# copy the offsets, column indices, and values to the output matrix
494
- cusparseSpGEMM_copy (handle (), transa, transb, Ref {T} (alpha) , descA, descB,
495
- Ref {T} (beta) , descC, T, algo, spgemm_desc)
528
+ cusparseSpGEMM_copy (handle (), transa, transb, alpha_ref , descA, descB,
529
+ beta_ref , descC, T, algo, spgemm_desc)
496
530
CUDA. unsafe_free! (buffer2)
497
531
end
498
532
0 commit comments