Commit 0b279f4

[https://nvbugs/5456493][feat] Add fp8 bmm on sm120 (#9687)
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
1 parent 4e55b83 commit 0b279f4

File tree

7 files changed, +301 -185 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh

Lines changed: 83 additions & 4 deletions
@@ -27,10 +27,10 @@
 #include <string>
 #include <vector>

-#include "6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh"
 #include "ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh"
 #include "fp8_blockscale_mma_utils.cuh"
 #include "fp8_blockscale_tma_utils.cuh"
+#include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
 #include "tensorrt_llm/common/config.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
@@ -1653,14 +1653,28 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
     using Params = typename GemmKernel::Params;
     using Arguments = typename GemmKernel::Arguments;
     using ProblemShape = typename GemmKernel::ProblemShape;
+    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);

     auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
     auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
     auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
     auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
     auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
-    Arguments args = {ptr_A, ptr_B, ptr_SFA, ptr_SFB, ptr_D};
-    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);
+
+    int32_t ld_a = shape_k;
+    int32_t stride_a = shape_m * shape_k;
+    int32_t ld_b = shape_k;
+    int32_t stride_b = shape_n * shape_k;
+    int32_t ld_d = shape_n;
+    int32_t stride_d = shape_m * shape_n;
+
+    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
+    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
+    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
+    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
+    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
+
+    Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};

     Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
     auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
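
The `make_stride(ld, Int<1>{}, batch_stride)` triples introduced above describe row-major matrices to CuTe: mode 0 (rows) advances by the leading dimension, mode 1 (columns) by a compile-time 1, and mode 2 by the element distance between batch instances. The following host-side sketch is not part of the commit; it only assumes CUTLASS's CuTe headers are on the include path, and the shapes are made up:

// Sketch: the same (ld, 1, batch_stride) convention as the diff, checked
// against the closed-form offset m * ld + k + l * batch_stride.
#include <cute/layout.hpp>
#include <cstdio>

int main()
{
    using namespace cute;
    int M = 4, K = 8, L = 2; // illustrative rows, cols, batch count
    int ld_a = K;            // row-major: leading dimension equals K
    int stride_a = M * K;    // elements between consecutive batch instances
    auto layout_a = make_layout(make_shape(M, K, L), make_stride(ld_a, Int<1>{}, stride_a));
    printf("offset(m=1, k=2, l=1) = %d, expected %d\n", (int) layout_a(1, 2, 1), 1 * ld_a + 2 + 1 * stride_a);
    return 0;
}

In `gemm_dispatch_sm120` the batch extent of the problem shape is fixed to 1, so only the first two modes do real work; the batched dispatcher added below reuses the same convention with `num_problems` as the fourth problem-shape mode.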
@@ -1914,6 +1928,65 @@ void strided_batch_gemm_dispatch_sm89(__nv_fp8_e4m3* mat_a, int ld_a, int stride
         stride_scales_b);
 }

+void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b,
+    int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, float* scales_a, int stride_scales_a, float* scales_b,
+    uint32_t num_problems, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, cudaStream_t stream,
+    int num_device_sms = kNumDeviceSMs)
+{
+    if (num_device_sms < 0)
+    {
+        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
+    }
+    using ElementInput = cute::float_e4m3_t;
+    using ElementOutput = cute::bfloat16_t;
+    using ElementAccum = float;
+    using ElementBlockScale = int32_t;
+    using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
+    using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
+    using Params = typename GemmKernel::Params;
+    using Arguments = typename GemmKernel::Arguments;
+    using ProblemShape = typename GemmKernel::ProblemShape;
+    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
+
+    auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
+    auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
+    auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
+    auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
+    auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
+
+    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
+    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
+    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
+    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
+    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
+
+    Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
+
+    Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
+    auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
+
+    cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
+    auto result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel cannot launch: %s", cudaGetErrorString(result));
+
+    cudaLaunchConfig_t launch_config;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = 1;
+
+    launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
+    launch_config.blockDim = GemmKernel::get_block_shape();
+    launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
+    launch_config.stream = stream;
+    launch_config.attrs = attrs;
+    launch_config.numAttrs = 1;
+
+    cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);
+
+    result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
+}
+
 void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,
     int stride_a, int stride_scales_a, __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, int ld_b,
     int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, uint32_t num_problems, uint32_t shape_m,
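
`strided_batch_gemm_dispatch_sm120` mirrors the sm89 dispatcher but launches through `cudaLaunchKernelEx` so it can attach `cudaLaunchAttributeProgrammaticStreamSerialization`, the opt-in for programmatic dependent launch (the next kernel on the stream may begin its prologue before this one fully retires). Here is a standalone sketch of that launch pattern with a trivial placeholder kernel standing in for the GEMM; the kernel, sizes, and stream choice are all illustrative:

// Sketch of the cudaLaunchKernelEx + PDL-attribute pattern used above.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* out)
{
    out[threadIdx.x] = static_cast<float>(threadIdx.x);
}

int main()
{
    float* d_out = nullptr;
    cudaMalloc(&d_out, 128 * sizeof(float));

    // Allow this launch to overlap with the preceding kernel on the stream.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = dim3(1);
    cfg.blockDim = dim3(128);
    cfg.dynamicSmemBytes = 0;
    cfg.stream = nullptr; // default stream, for the sketch only
    cfg.attrs = attrs;
    cfg.numAttrs = 1;

    cudaError_t err = cudaLaunchKernelEx(&cfg, dummy_kernel, d_out);
    printf("launch: %s\n", cudaGetErrorString(err));
    cudaDeviceSynchronize();
    cudaFree(d_out);
    return 0;
}

Note how the diff checks `cudaGetLastError` once after `cudaFuncSetAttribute` (launch preconditions) and again after `cudaLaunchKernelEx` (launch-time errors); the sketch collapses both checks into the returned `cudaError_t`.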
@@ -1941,12 +2014,18 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
     }

     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89 || arch == 120)
+    if (arch == 89)
     {
         strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
             scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
         return;
     }
+    if (arch == 120)
+    {
+        strided_batch_gemm_dispatch_sm120(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
+            scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
+        return;
+    }
     if (kDeepGemmEnabled)
     {
         strided_batch_gemm_dispatch(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
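
The last hunk splits the old `arch == 89 || arch == 120` branch so sm120 stops reusing the sm89 path and gets its own dispatcher. A hedged sketch of that dispatch shape, using raw device attributes instead of `tensorrt_llm::common::getSMVersion` (assumed here to follow the usual major * 10 + minor convention) and print statements in place of the real dispatchers:

// Sketch: compute the SM version and branch per architecture, as the
// dispatch code above does. The dispatcher calls are stubbed with prints.
#include <cuda_runtime.h>
#include <cstdio>

static int sm_version()
{
    int device = 0, major = 0, minor = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    return major * 10 + minor; // 89 on Ada, 120 on sm120 parts
}

int main()
{
    int arch = sm_version();
    if (arch == 89)
        printf("would call strided_batch_gemm_dispatch_sm89\n");
    else if (arch == 120)
        printf("would call strided_batch_gemm_dispatch_sm120\n");
    else
        printf("arch %d: fall through to the DeepGEMM / generic paths\n", arch);
    return 0;
}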
