@@ -722,6 +722,150 @@ bool test6() {
722
722
return !error;
723
723
}
724
724
725
// Performs D = scale_d * relu(alpha * scale_a * scale_b * (A * B) + beta * C)
// with cuBLASLt, and writes the absolute maximum of D prior to the scale_d
// scaling into *amax_d (see the expected-value derivation by test7 for the
// observed semantics — TODO confirm scale-pointer behavior for FP32 against
// the cuBLASLt documentation).
//
// All matrices are FP32, column-major, described by the caller-supplied
// layout descriptors. Hard-coded scaling factors: scale_a = 3, scale_b = 5,
// scale_d = 7. amax_d must point to device-accessible memory. Runs on the
// default stream and synchronizes before returning.
void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k,
             const float *A, const float *B, const float *C, float *D,
             float *alpha, float *beta,
             int lda, int ldb, int ldc, int ldd,
             cublasLtMatrixLayout_t Adesc,
             cublasLtMatrixLayout_t Bdesc,
             cublasLtMatrixLayout_t Cdesc,
             cublasLtMatrixLayout_t Ddesc,
             float *amax_d) {
  cublasLtMatmulDesc_t matmulDesc = NULL;
  cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

  // Scale factors must live in device-accessible memory; managed memory lets
  // the host initialize them directly.
  float *scale_a;
  float *scale_b;
  float *scale_d;
  cudaMallocManaged(&scale_a, sizeof(float));
  cudaMallocManaged(&scale_b, sizeof(float));
  cudaMallocManaged(&scale_d, sizeof(float));
  scale_a[0] = 3;
  scale_b[0] = 5;
  scale_d[0] = 7;

  // Each attribute stores the pointer value itself, hence &scale_x /
  // sizeof(pointer).
  cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(scale_a));
  cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(scale_b));
  cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scale_d, sizeof(scale_d));
  cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &amax_d, sizeof(amax_d));

  // ReLU epilogue: negative entries of the accumulator are clamped to zero.
  cublasLtEpilogue_t ep = CUBLASLT_EPILOGUE_RELU;
  cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &ep, sizeof(ep));

  cublasLtMatmul(ltHandle, matmulDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, NULL, NULL, 0, 0);

  // Default-stream sync so the scale buffers can be released safely.
  cudaStreamSynchronize(0);
  cublasLtMatmulDescDestroy(matmulDesc);

  // Fix: these managed allocations were previously leaked on every call.
  cudaFree(scale_a);
  cudaFree(scale_b);
  cudaFree(scale_d);
}
760
+
761
// clang-format off
// A (4*3)        B (3*2)
//  6 10 14        5  4
//  7 11 15       -3 -2
//  8 12 16        1  0
//  9 13 17        p  p
//
// (alpha*scale_a*scale_b) * A*B + C = D (pre-epilogue accumulator)
// 2*3*5 *  6 10 14    5  4   +  -10000 -5000  =  30 *  14  4  +  -10000 -5000  =  -9580 -4880
//          7 11 15   -3 -2       2000  6000          17  6       2000  6000      2510  6180
//          8 12 16    1  0       3000  7000          20  8       3000  7000      3600  7240
//          9 13 17    p  p       4000  8000          23 10       4000  8000      4690  8300
//
// The ReLU epilogue then zeroes the negative entries (-9580 and -4880 -> 0),
// and amax_d records the absolute maximum before the final scaling: 8300.
//
// scale_d * relu(D) = final D
// 7 *     0     0        0      0
//      2510  6180    17570  43260
//      3600  7240    25200  50680
//      4690  8300    32830  58100
// clang-format on
779
+
780
// Verifies fgemmlt: a 4x3 * 3x2 FP32 matmul with A/B/D scale factors,
// a ReLU epilogue, and amax_d reporting. Returns true on success.
bool test7() {
  cublasLtHandle_t ltHandle;
  // Fix: the original text had "cublasLtCreate (<Handle)" — HTML-entity
  // garbling of &ltHandle.
  cublasLtCreate(&ltHandle);
  // constexpr already implies const; the redundant "const" was dropped.
  constexpr int m = 4;
  constexpr int n = 2;
  constexpr int k = 3;
  constexpr int lda = m;
  constexpr int ldb = m;
  constexpr int ldc = m;
  constexpr int ldd = m;
  void *Adev;
  void *Bdev;
  void *Cdev;
  void *Ddev;
  cudaMalloc(&Adev, lda * k * sizeof(float));
  cudaMalloc(&Bdev, ldb * n * sizeof(float));
  cudaMalloc(&Cdev, ldc * n * sizeof(float));
  cudaMalloc(&Ddev, ldd * n * sizeof(float));

  // Column-major data; B's leading dimension is padded (ldb = 4 > k = 3),
  // so the 99 entries are padding that the matmul never reads.
  float Ahost[lda * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
  float Bhost[ldb * n] = {5, -3, 1, 99, 4, -2, 0, 99};
  // Fix: first entry was -1000, disagreeing with the -10000 used in the
  // expected-value derivation comment above fgemmlt's caller; the final
  // result is unaffected (ReLU zeroes that entry either way).
  float Chost[ldc * n] = {-10000, 2000, 3000, 4000, -5000, 6000, 7000, 8000};

  cudaMemcpy(Adev, Ahost, lda * k * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(Bdev, Bhost, ldb * n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(Cdev, Chost, ldc * n * sizeof(float), cudaMemcpyHostToDevice);

  cublasLtMatrixLayout_t Adesc_col_major = NULL,
                         Bdesc_col_major = NULL,
                         Cdesc_col_major = NULL,
                         Ddesc_col_major = NULL;
  cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_32F, m, k, lda);
  cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_32F, k, n, ldb);
  cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32F, m, n, ldc);
  cublasLtMatrixLayoutCreate(&Ddesc_col_major, CUDA_R_32F, m, n, ldd);

  float alpha = 2;
  float beta = 1;

  // Matmul; amax_d is managed memory so the host can read it after sync.
  float *amax_d;
  cudaMallocManaged(&amax_d, sizeof(float));

  fgemmlt(ltHandle, m, n, k, (const float *)Adev, (const float *)Bdev, (const float *)Cdev, (float *)Ddev,
          &alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major, amax_d);
  cudaStreamSynchronize(0);

  // Check result: negatives zeroed by ReLU, remainder scaled by scale_d = 7.
  float Dhost[ldd * n];
  cudaMemcpy(Dhost, Ddev, ldd * n * sizeof(float), cudaMemcpyDeviceToHost);

  bool error = false;
  float D_ref[ldd * n] = {0, 17570, 25200, 32830, 0, 43260, 50680, 58100};
  for (int i = 0; i < ldd * n; i++) {
    if (Dhost[i] != D_ref[i]) {
      error = true;
      break;
    }
  }
  // amax_d is the absolute max of D before the scale_d scaling.
  if (*amax_d != 8300)
    error = true;

  printf("d:\n");
  for (int i = 0; i < ldd * n; i++)
    printf("%f, ", Dhost[i]);
  printf("\n");
  printf("amax_d:%f\n", *amax_d);

  if (error) {
    printf("error\n");
  } else {
    printf("success\n");
  }

  cublasLtDestroy(ltHandle);
  cublasLtMatrixLayoutDestroy(Adesc_col_major);
  cublasLtMatrixLayoutDestroy(Bdesc_col_major);
  cublasLtMatrixLayoutDestroy(Cdesc_col_major);
  cublasLtMatrixLayoutDestroy(Ddesc_col_major);
  cudaFree(Adev);
  cudaFree(Bdev);
  // Fix: Cdev was previously leaked.
  cudaFree(Cdev);
  cudaFree(Ddev);
  cudaFree(amax_d);

  return !error;
}
867
+
868
+
725
869
// clang-format off
726
870
// A (4*3) B (2*3)
727
871
// 6 10 14 5 -3 1
@@ -750,5 +894,6 @@ int main() {
750
894
pass = test4 () && pass;
751
895
pass = test5 () && pass;
752
896
pass = test6 () && pass;
897
+ pass = test7 () && pass;
753
898
return pass ? 0 : 1 ;
754
899
}
0 commit comments