@@ -549,51 +549,85 @@ end
549
549
function gemm (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSR{T} ,
550
550
B:: CuSparseMatrixCSR{T} , index:: SparseChar , algo:: cusparseSpGEMMAlg_t = CUSPARSE_SPGEMM_DEFAULT) where {T}
551
551
552
- m,k = size (A)
553
- l,n = size (B)
552
+ m, k = size (A)
553
+ l, n = size (B)
554
554
555
- (k != l) && throw (DimensionMismatch (" A must have the same the number of columns that B has as rows, but A has $k columns and B has $l columns " ))
555
+ (k != l) && throw (DimensionMismatch (" A must have the same number of columns that B has as rows, but A has $k columns and B has $l rows. " ))
556
556
! (transa == ' N' && transb == ' N' ) && throw (ArgumentError (" Sparse matrix-matrix multiplication only supports transa ($transa ) = 'N' and transb ($transb ) = 'N'" ))
557
557
558
+ alpha_ref = Ref {T} (convert (T, alpha))
559
+ beta_ref = Ref {T} (zero (T))
560
+
558
561
descA = CuSparseMatrixDescriptor (A, index)
559
562
descB = CuSparseMatrixDescriptor (B, index)
560
563
561
- rowPtr = CuVector {Cint} (undef, m+ 1 )
564
+ rowPtr = CuVector {Cint} (undef, m + 1 )
562
565
descC = CuSparseMatrixDescriptor (CuSparseMatrixCSR, rowPtr, T, Cint, m, n, index)
563
566
564
567
spgemm_desc = CuSpGEMMDescriptor ()
565
568
566
569
buffer1 = CuVector {UInt8} (undef, 0 )
567
570
buffer2 = CuVector {UInt8} (undef, 0 )
568
- GC. @preserve buffer1 buffer1 begin
571
+ GC. @preserve buffer1 buffer2 rowPtr begin
569
572
# compute an upper bound of the memory required for the intermediate products.
570
573
function buffer1Size ()
571
574
out = Ref {Csize_t} (0 )
572
575
cusparseSpGEMM_workEstimation (
573
- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} ( 0 ) ,
576
+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
574
577
descC, T, algo, spgemm_desc, out, CU_NULL)
575
578
return out[]
576
579
end
577
580
with_workspace (buffer1, buffer1Size) do buffer
578
581
out = Ref {Csize_t} (sizeof (buffer))
579
582
cusparseSpGEMM_workEstimation (
580
- handle (), transa, transb, Ref {T} (alpha) , descA, descB, Ref {T} ( 0 ) ,
583
+ handle (), transa, transb, alpha_ref , descA, descB, beta_ref ,
581
584
descC, T, algo, spgemm_desc, out, buffer)
582
585
end
583
586
584
587
# compute the structure of the output matrix and its values in a temporary buffer
585
- function buffer2Size ()
586
- out = Ref {Csize_t} (0 )
587
- cusparseSpGEMM_compute (
588
- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (0 ),
589
- descC, T, algo, spgemm_desc, out, CU_NULL)
590
- return out[]
591
- end
592
- with_workspace (buffer2, buffer2Size) do buffer
593
- out = Ref {Csize_t} (sizeof (buffer))
594
- cusparseSpGEMM_compute (
595
- handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (0 ),
596
- descC, T, algo, spgemm_desc, out, buffer)
588
+ if algo == CUSPARSE_SPGEMM_DEFAULT || algo == CUSPARSE_SPGEMM_ALG1
589
+ function buffer2Size ()
590
+ out = Ref {Csize_t} (0 )
591
+ cusparseSpGEMM_compute (
592
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
593
+ descC, T, algo, spgemm_desc, out, CU_NULL)
594
+ return out[]
595
+ end
596
+ with_workspace (buffer2, buffer2Size) do buffer
597
+ out = Ref {Csize_t} (sizeof (buffer))
598
+ cusparseSpGEMM_compute (
599
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
600
+ descC, T, algo, spgemm_desc, out, buffer)
601
+ end
602
+ elseif algo == CUSPARSE_SPGEMM_ALG2 || algo == CUSPARSE_SPGEMM_ALG3
603
+ chunk_fraction = Cfloat (0.2 ) # as per NVIDIA example (make it configurable?)
604
+ function buffer3Size ()
605
+ out = Ref {Csize_t} (0 )
606
+ cusparseSpGEMM_estimateMemory (
607
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
608
+ descC, T, algo, spgemm_desc, chunk_fraction, out, CU_NULL, 0 )
609
+ return out[]
610
+ end
611
+ with_workspace (buffer3Size) do buffer3
612
+ function buffer2Size ()
613
+ out = Ref {Csize_t} (0 )
614
+ cusparseSpGEMM_estimateMemory (
615
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
616
+ descC, T, algo, spgemm_desc, chunk_fraction, sizeof (buffer3),
617
+ buffer3, out)
618
+ return out[]
619
+ end
620
+ with_workspace (buffer2, buffer2Size) do buffer
621
+ unsafe_free! (buffer3)
622
+
623
+ out = Ref {Csize_t} (sizeof (buffer))
624
+ cusparseSpGEMM_compute (
625
+ handle (), transa, transb, alpha_ref, descA, descB, beta_ref,
626
+ descC, T, algo, spgemm_desc, out, buffer)
627
+ end
628
+ end
629
+ else
630
+ throw (ArgumentError (" Unsupported SpGEMM algorithm: $algo " ))
597
631
end
598
632
CUDA. unsafe_free! (buffer1)
599
633
0 commit comments