
Commit b59f670

Fix: e5m2 * e5m2 unsupported
It also turns out that matrix A must be transposed for cuBLASLt.
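The constraint behind this fix is the FP8 support matrix that the new comments in the diff spell out: the inputs may be `e4m3 * e4m3` or `e5m2 * e4m3`, but never `e5m2 * e5m2`, and on Ada, Hopper, and Blackwell-class GPUs the "A" operand must be supplied transposed, the so-called "TN" layout. The following is a minimal, illustrative sketch of a descriptor setup that satisfies those rules; it is not code from this commit, the `fp8_tn_descriptors` helper and its dimensions are made up, and a real program should check every returned `cublasStatus_t`.

#include <cstddef>    // `std::size_t`
#include <cublasLt.h> // cuBLASLt descriptors

// Hypothetical helper: configure an FP8 GEMM the way cuBLASLt accepts it on Ada / Hopper.
void fp8_tn_descriptors(std::size_t m, std::size_t n, std::size_t k) {
    cublasLtMatmulDesc_t descriptor = nullptr;
    cublasLtMatmulDescCreate(&descriptor, CUBLAS_COMPUTE_32F, CUDA_R_32F);

    // "A" transposed, "B" not - the "TN" combination this commit switches to.
    cublasOperation_t const a_transpose = CUBLAS_OP_T, b_transpose = CUBLAS_OP_N;
    cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSA, &a_transpose, sizeof(a_transpose));
    cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSB, &b_transpose, sizeof(b_transpose));

    // Inputs: `e4m3 * e4m3` (or `e5m2 * e4m3`) - `e5m2 * e5m2` is the unsupported pair.
    cublasLtMatrixLayout_t a_layout = nullptr, b_layout = nullptr, d_layout = nullptr;
    cublasLtMatrixLayoutCreate(&a_layout, CUDA_R_8F_E4M3, k, m, k); // A is stored (k x m) for `CUBLAS_OP_T`
    cublasLtMatrixLayoutCreate(&b_layout, CUDA_R_8F_E4M3, k, n, k); // B stays (k x n) for `CUBLAS_OP_N`
    cublasLtMatrixLayoutCreate(&d_layout, CUDA_R_32F, m, n, m);     // output in a wider type, e.g. `f32`

    // A real benchmark would call `cublasLtMatmul` here; either way, release the descriptors.
    cublasLtMatrixLayoutDestroy(a_layout), cublasLtMatrixLayoutDestroy(b_layout), cublasLtMatrixLayoutDestroy(d_layout);
    cublasLtMatmulDescDestroy(descriptor);
}

The diff below applies the same rule by flipping `a_transpose` from `CUBLAS_OP_N` to `CUBLAS_OP_T` and by dropping the `e5m2 * e5m2` benchmark.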
1 parent 1240e02 commit b59f670

File tree

1 file changed: +51 -44 lines changed


less_slow.cpp

Lines changed: 51 additions & 44 deletions
@@ -3143,54 +3143,63 @@ BENCHMARK(cublas_tops<int8_t, int32_t>)->RangeMultiplier(2)->Range(8, 16384)->Co
  * with different factors is also supported - a common technique used in extreme quantization both on
  * CPUs and GPUs.
  *
+ * ! "A" and "B" inputs can't both be `e5m2`: it's either `e4m3 * e4m3` or `e5m2 * e4m3`.
+ * ! Even if the `e4m3 * e4m3` scheme is used, only a very specific set of "C" and "D" types is allowed.
+ * ! The "A" matrix must be transposed on Ada, Hopper, and Blackwell!
+ * ! For `FP4`, similarly, the only consistently used configuration is `e2m1 * e2m1`.
+ *
  * @see "Using the cuBLASLt API" docs: https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api
+ * @note To avoid including the `<cuda_fp8.h>` header, we define alternatives to `__nv_fp8_e4m3` & `__nv_fp8_e5m2`.
  */
 #include <cublasLt.h>
-#include <cuda_fp8.h> // `__nv_fp8*` types
+
+enum fp8_e4m3_t : unsigned char {};
+enum fp8_e5m2_t : unsigned char {};
+enum fp4_e2m1_t : unsigned char {};
+static_assert(!std::is_same_v<fp8_e4m3_t, fp8_e5m2_t>);
 
 template <typename scalar_type_>
 cudaDataType_t to_cuda_data_type() {
-    if constexpr (std::is_same<scalar_type_, __nv_fp8_e4m3>::value) return CUDA_R_8F_E4M3;
-    if constexpr (std::is_same<scalar_type_, __nv_fp8_e5m2>::value) return CUDA_R_8F_E5M2;
-    if constexpr (std::is_same<scalar_type_, float>::value) return CUDA_R_32F;
-    if constexpr (std::is_same<scalar_type_, std::int8_t>::value) return CUDA_R_8I;
-    if constexpr (std::is_same<scalar_type_, std::uint8_t>::value) return CUDA_R_8U;
+    if constexpr (std::is_same_v<scalar_type_, fp8_e4m3_t>) return CUDA_R_8F_E4M3;
+    if constexpr (std::is_same_v<scalar_type_, fp8_e5m2_t>) return CUDA_R_8F_E5M2;
+    if constexpr (std::is_same_v<scalar_type_, fp4_e2m1_t>) return CUDA_R_4F_E2M1;
+    if constexpr (std::is_same_v<scalar_type_, float>) return CUDA_R_32F;
+    if constexpr (std::is_same_v<scalar_type_, __half>) return CUDA_R_16F;
+    if constexpr (std::is_same_v<scalar_type_, std::int8_t>) return CUDA_R_8I;
+    if constexpr (std::is_same_v<scalar_type_, std::uint8_t>) return CUDA_R_8U;
     throw std::invalid_argument("Unknown CUDA type");
 }
 
-template <typename scalar_type_>
-struct cuda_storage_type {
-    using scalar_type = scalar_type_;
-};
-
-template <>
-struct cuda_storage_type<__nv_fp8_e4m3> {
-    using scalar_type = __nv_fp8_storage_t;
-};
-
-template <>
-struct cuda_storage_type<__nv_fp8_e5m2> {
-    using scalar_type = __nv_fp8_storage_t;
-};
-
 template <typename input_scalar_type_, typename output_scalar_type_ = input_scalar_type_>
 static void cublaslt_tops(bm::State &state) {
 
     // Matrix size and leading dimensions
     std::size_t n = static_cast<std::size_t>(state.range(0));
+    // To use tensor- or block-scaled FP8 kernels, all matrix dimensions must meet the optimal
+    // requirements listed in "Tensor Core Usage" (i.e. pointers and matrix dimensions must support
+    // 16-byte alignment).
+    if (n % 16 != 0) throw std::invalid_argument("Tensor side not properly aligned.");
     int lda = static_cast<int>(n), ldb = static_cast<int>(n), ldc = static_cast<int>(n);
-    constexpr bool same_type = std::is_same_v<input_scalar_type_, output_scalar_type_>;
-    cublasOperation_t a_transpose = CUBLAS_OP_N;
-    cublasOperation_t b_transpose = CUBLAS_OP_N;
+
+    // "A" must be transposed and "B" non-transposed (the "TN" format) on Ada (compute capability 8.9),
+    // Hopper (compute capability 9.0), and Blackwell GeForce (compute capability 12.x) GPUs.
+    cublasOperation_t const a_transpose = CUBLAS_OP_T;
+    cublasOperation_t const b_transpose = CUBLAS_OP_N;
 
     // Unified memory for large matrices
-    using input_storage_type = typename cuda_storage_type<input_scalar_type_>::scalar_type;
-    unified_array<input_storage_type> a(n * n), b(n * n);
+    unified_array<input_scalar_type_> a(n * n), b(n * n);
     unified_array<output_scalar_type_> c(n * n), d(n * n);
 
-    // With unified memory, we don't even need Thrust to initialize the data
-    std::iota(a.begin(), a.end(), 0);
-    std::iota(b.begin(), b.end(), 0);
+    // With unified memory, we don't even need Thrust to initialize the data,
+    // but we can't use `std::iota` with `enum` types, as they don't provide
+    // an implicit conversion from integers, so let's use a `std::transform`
+    // with a mutable lambda state:
+    {
+        std::uint64_t counter;
+        auto iota_lambda = [&counter](auto) mutable { return static_cast<input_scalar_type_>(counter++); };
+        counter = 0, std::transform(a.begin(), a.end(), a.begin(), iota_lambda);
+        counter = 0, std::transform(b.begin(), b.end(), b.begin(), iota_lambda);
+    }
     std::fill(c.begin(), c.end(), 0);
     std::fill(d.begin(), d.end(), 0);
 
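Since `fp8_e4m3_t` and friends are opaque `enum`s, `std::iota` can no longer assign integers to them, which is why the hunk above switches to `std::transform` with a mutable counter. The same initialization can also be written with `std::generate`; this is just a sketch, and the `iota_fill` helper name is hypothetical, not part of the commit:

#include <algorithm> // `std::generate`
#include <cstdint>   // `std::uint64_t`

// Fill any range with 0, 1, 2, ... by casting an explicit counter to the destination scalar type.
template <typename scalar_type_, typename array_type_>
void iota_fill(array_type_ &array) {
    std::uint64_t counter = 0;
    std::generate(array.begin(), array.end(), [&counter] { return static_cast<scalar_type_>(counter++); });
}

// Usage mirroring the benchmark above: `iota_fill<fp8_e4m3_t>(a);` and `iota_fill<fp8_e4m3_t>(b);`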
@@ -3205,17 +3214,15 @@ static void cublaslt_tops(bm::State &state) {
     cublas_check(
         cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_TRANSB, &b_transpose, sizeof(b_transpose)));
 
-    // Set per-tensor scaling attributes (using 1.0f as the default scaling factors).
-    float a_scale = 1.0f, b_scale = 1.0f, c_scale = 1.0f, d_scale = 1.0f;
-    cublas_check(
-        cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale, sizeof(a_scale)));
-    cublas_check(
-        cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale, sizeof(b_scale)));
-    cublas_check(
-        cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale, sizeof(c_scale)));
-    cublas_check(
-        cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale, sizeof(d_scale)));
-
+    // We can also set the per-tensor scaling attributes, but they must be pre-allocated in device memory!
+    // If not specified, or set to NULL, the scaling factor is assumed to be 1.
+    //
+    // float a_scale = 1.0f, b_scale = 1.0f, c_scale = 1.0f, d_scale = 1.0f; //! Can't be on the host like this
+    // cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale, sizeof(a_scale));
+    // cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale, sizeof(b_scale));
+    // cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale, sizeof(c_scale));
+    // cublasLtMatmulDescSetAttribute(descriptor, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale, sizeof(d_scale));
+    //
     // Create matrix layout descriptors for A, B, C, and D (output)
     // https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtFp8Matmul/sample_cublasLt_LtFp8Matmul.cu
     cublasLtMatrixLayout_t a_descriptor = nullptr, b_descriptor = nullptr, c_descriptor = nullptr,
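The commented-out block above hints at the correct approach: the `CUBLASLT_MATMUL_DESC_*_SCALE_POINTER` attributes store a pointer to the scaling factor, and that pointer must be dereferenceable from the device, so a stack-allocated host `float` won't work. Below is a minimal sketch of a valid setup; the `set_device_scale` helper and its ownership convention are illustrative assumptions, not part of this commit:

#include <cublasLt.h>
#include <cuda_runtime_api.h> // `cudaMalloc`, `cudaMemcpy`

// Copy a host-side scaling factor into device memory and attach it to the matmul descriptor.
// The returned pointer must outlive the `cublasLtMatmul` call; freeing it is the caller's job.
float *set_device_scale(cublasLtMatmulDesc_t descriptor, cublasLtMatmulDescAttributes_t attribute, float host_scale) {
    float *device_scale = nullptr;
    cudaMalloc((void **)&device_scale, sizeof(float));
    cudaMemcpy(device_scale, &host_scale, sizeof(float), cudaMemcpyHostToDevice);
    // The attribute value is the pointer itself, hence `&device_scale` and `sizeof(device_scale)`:
    cublasLtMatmulDescSetAttribute(descriptor, attribute, &device_scale, sizeof(device_scale));
    return device_scale;
}

// Usage: `float *a_scale = set_device_scale(descriptor, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, 1.0f);`

Since the benchmark already relies on unified memory, a one-element `unified_array<float>` would serve the same purpose without the explicit `cudaMalloc` / `cudaMemcpy`.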
@@ -3241,7 +3248,7 @@ static void cublaslt_tops(bm::State &state) {
 
     // Define scaling factors (using FP32 scalars)
     float alpha = 1.0f;
-    float beta = 0.0f;
+    float beta = 0.0f; // Can be non-zero only starting with 12.0
 
     for (auto _ : state) {
 
@@ -3272,8 +3279,8 @@ static void cublaslt_tops(bm::State &state) {
     cublas_check(cublasLtDestroy(handle));
 }
 
-BENCHMARK(cublaslt_tops<__nv_fp8_e4m3, float>)->RangeMultiplier(2)->Range(8, 16384)->Complexity(benchmark::oNCubed);
-BENCHMARK(cublaslt_tops<__nv_fp8_e5m2, float>)->RangeMultiplier(2)->Range(8, 16384)->Complexity(benchmark::oNCubed);
+BENCHMARK(cublaslt_tops<fp8_e4m3_t, float>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
+BENCHMARK(cublaslt_tops<fp8_e4m3_t, __half>)->RangeMultiplier(2)->Range(256, 16384)->Complexity(benchmark::oNCubed);
 
 /**
  * Here are the numbers one can expect on Nvidia H200 GPUs:
@@ -3286,7 +3293,7 @@ BENCHMARK(cublaslt_tops<__nv_fp8_e5m2, float>)->RangeMultiplier(2)->Range(8, 163
  * - `bf16`            @b 1'000 T      @b 1'047 T      -
  * - `f16`             @b 1'000 T      @b 1'056 T      @b 764 T
  * - `i8` & `u8`       @b 2'000 T      -               @b 122 T
- * - `e4m3` & `e5m2`   @b 2'000 T      -               -
+ * - `e4m3`            @b 2'000 T      -               @b 1'338 T
  * - `b1` XOR-based    -               @b 79 T         -
  * - `b1` AND-based    -               @b 8'439 T      -
  *
