@@ -3147,6 +3147,7 @@ BENCHMARK(cublas_tops<int8_t, int32_t>)->RangeMultiplier(2)->Range(8, 16384)->Co
  * ! Even if the `e4m3 * e4m3` scheme is used, only a very specific set of "C" and "D" types can be used.
  * ! The "A" matrix must be transposed on Ada, Hopper, and Blackwell!
  * ! Similarly, for `FP4`, the only consistently supported configuration is `e2m1 * e2m1`.
+ * ! The compute type must be `CUBLAS_COMPUTE_32F` for both single- and half-precision outputs.
  *
  * @see "Using the cuBLASLt API" docs: https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api
  * @note To avoid including the `<cuda_fp8.h>` header, we define alternatives to `__nv_fp8_e4m3` & `__nv_fp8_e5m2`.
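A minimal sketch of what such stand-ins could look like, assuming cuBLASLt only needs correctly sized one-byte host buffers plus the `CUDA_R_8F_E4M3` / `CUDA_R_8F_E5M2` layout tags, and that no host-side FP8 arithmetic is performed (the member name `bits` is hypothetical):

    #include <cstdint>

    struct fp8_e4m3_t { std::uint8_t bits; }; // stand-in for `__nv_fp8_e4m3`: storage only, no math
    struct fp8_e5m2_t { std::uint8_t bits; }; // stand-in for `__nv_fp8_e5m2`: storage only, no math
    static_assert(sizeof(fp8_e4m3_t) == 1 && sizeof(fp8_e5m2_t) == 1, "must match the 8-bit device types");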
@@ -3170,6 +3171,15 @@ cudaDataType_t to_cuda_data_type() {
     throw std::invalid_argument("Unknown CUDA type");
 }
 
+template <typename scalar_type_>
+cublasComputeType_t to_cublas_compute_type() {
+    if constexpr (std::is_same_v<scalar_type_, double>) return CUBLAS_COMPUTE_64F;
+    if constexpr (std::is_same_v<scalar_type_, float>) return CUBLAS_COMPUTE_32F;
+    if constexpr (std::is_same_v<scalar_type_, __half>) return CUBLAS_COMPUTE_16F;
+    if constexpr (std::is_same_v<scalar_type_, std::int32_t>) return CUBLAS_COMPUTE_32I;
+    throw std::invalid_argument("Unknown CUDA type");
+}
+
 template <typename input_scalar_type_, typename output_scalar_type_ = input_scalar_type_>
 static void cublaslt_tops(bm::State &state) {
 
@@ -3179,7 +3189,7 @@ static void cublaslt_tops(bm::State &state) {
     // requirements listed in Tensor Core Usage (i.e. pointers and matrix dimensions must support
     // 16-byte alignment).
     if (n % 16 != 0) throw std::invalid_argument("Tensor side not properly aligned.");
-    int lda = static_cast<int>(n), ldb = static_cast<int>(n), ldc = static_cast<int>(n);
+    int lda = static_cast<int>(n), ldb = static_cast<int>(n), ldc = static_cast<int>(n), ldd = static_cast<int>(n);
 
     // "A" must be transposed and "B" non-transposed (the "TN" format) on Ada (compute capability 8.9),
     // Hopper (compute capability 9.0), and Blackwell GeForce (compute capability 12.x) GPUs.
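The "TN" constraint mentioned above presumably translates into the following operand settings, which the code later passes to `cublasLtMatmulDescSetAttribute` (a plausible sketch; only `a_transpose` is visible in the hunks below, `b_transpose` is assumed):

    cublasOperation_t a_transpose = CUBLAS_OP_T; // "A" is transposed
    cublasOperation_t b_transpose = CUBLAS_OP_N; // "B" stays non-transposed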
@@ -3208,7 +3218,8 @@ static void cublaslt_tops(bm::State &state) {
 
     // Create the matmul descriptor.
     cublasLtMatmulDesc_t descriptor = nullptr;
-    cublas_check(cublasLtMatmulDescCreate(&descriptor, CUBLAS_COMPUTE_32F, to_cuda_data_type<output_scalar_type_>()));
+    cublas_check(cublasLtMatmulDescCreate(&descriptor, to_cublas_compute_type<float>(),
+                                          to_cuda_data_type<output_scalar_type_>()));
     cublas_check(
         cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSA, &a_transpose, sizeof(a_transpose)));
     cublas_check(
@@ -3230,7 +3241,7 @@ static void cublaslt_tops(bm::State &state) {
     cublas_check(cublasLtMatrixLayoutCreate(&a_descriptor, to_cuda_data_type<input_scalar_type_>(), n, n, lda));
     cublas_check(cublasLtMatrixLayoutCreate(&b_descriptor, to_cuda_data_type<input_scalar_type_>(), n, n, ldb));
     cublas_check(cublasLtMatrixLayoutCreate(&c_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldc));
-    cublas_check(cublasLtMatrixLayoutCreate(&d_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldc));
+    cublas_check(cublasLtMatrixLayoutCreate(&d_descriptor, to_cuda_data_type<output_scalar_type_>(), n, n, ldd));
 
     // Create a preference handle and set workspace limit (0 in this example).
     cublasLtMatmulPreference_t preference = nullptr;
@@ -3280,7 +3291,6 @@ static void cublaslt_tops(bm::State &state) {
 }
 
 BENCHMARK(cublaslt_tops<fp8_e4m3_t, float>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
-BENCHMARK(cublaslt_tops<fp8_e4m3_t, __half>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
 
 /**
  * Here are the numbers one can expect on an Nvidia H200 GPU: