Commit 8777917

CarstyYou authored and fredricz-20070104 committed

[TRTLLM-1234][feat] Add fp8 blockscaled Gemm for sm120 (NVIDIA#8844)

Signed-off-by: CarstyYou <[email protected]>
Signed-off-by: FredricZ-2007 <[email protected]>
1 parent c5e8d4f commit 8777917

File tree

5 files changed: +13 -13 lines changed


cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ set_cuda_architectures(fb_gemm_src 89 90 100f 120f)
 # ${INSTANTIATION_GENERATION_DIR}/fp8_rowwise_gemm)

 add_library(fp8_blockscale_gemm_src STATIC ${FP8_BLOCKSCALE_GEMM_SRC_CU})
-set_cuda_architectures(fp8_blockscale_gemm_src 89 90 100f)
+set_cuda_architectures(fp8_blockscale_gemm_src 89 90 100f 120f)

 set(GEMM_SWIGLU_SM90_SRC_CU
     ${CMAKE_CURRENT_SOURCE_DIR}/fused_gated_gemm/gemm_swiglu_e4m3.cu)
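The build change is the whole enabler for this commit: the fp8_blockscale_gemm_src target now also compiles for the 120f entry, matching the fb_gemm_src target shown in the hunk header. Assuming the suffix follows the usual CUDA convention, 100f/120f denote family-specific compute capabilities (SM 100 and SM 120), so without this line the kernels touched below would simply not exist in an SM 120 build.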

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh

Lines changed: 8 additions & 9 deletions

@@ -1622,16 +1622,15 @@ void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a,
     dim3 grid = dim3(grid_m, grid_n, grid_k);
     dim3 block = dim3(kThreadCount, 1, 1);

-    if (kSmemSize > (48 << 10))
-    {
-        cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
-            cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
-        auto result = cudaGetLastError();
-        TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));
-    }
+    auto result = cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", cudaGetErrorString(result));

     ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<GemmKernel>
         <<<grid, block, kSmemSize, stream>>>(shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b);
+
+    result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel runtime error: %s", cudaGetErrorString(result));
 }

 void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b, __nv_bfloat16* mat_d, int ld_d,
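Besides enabling SM 120, this hunk tightens error handling: cudaFuncSetAttribute is now called unconditionally and its own return code is checked (the old code only raised the shared-memory cap when kSmemSize exceeded the 48 KiB default, and inspected cudaGetLastError() instead of the call's return value), and a second check after the asynchronous launch now catches launch failures. A minimal self-contained sketch of the same pattern, with a made-up dummy_kernel standing in for the TensorRT-LLM kernel:

// launch_check.cu - sketch of the set-attribute / launch / check pattern above.
// dummy_kernel and the sizes are hypothetical, for illustration only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out)
{
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

int main()
{
    constexpr int kSmemSize = 64 << 10; // 64 KiB, above the 48 KiB default cap
    float* out = nullptr;
    cudaMalloc(&out, 256 * sizeof(float));

    // Check the API call's own return code. Setting a value at or below the
    // default cap is harmless, so no "> 48 KiB" guard is needed.
    cudaError_t result = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
    if (result != cudaSuccess)
    {
        std::printf("cannot launch: %s\n", cudaGetErrorString(result));
        return 1;
    }

    dummy_kernel<<<1, 256, kSmemSize>>>(out);

    // Kernel launches are asynchronous; cudaGetLastError() picks up errors
    // detected at launch time (e.g. requesting too much shared memory).
    result = cudaGetLastError();
    if (result != cudaSuccess)
    {
        std::printf("runtime error: %s\n", cudaGetErrorString(result));
        return 1;
    }

    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}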
@@ -1643,7 +1642,7 @@ void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b
     }
 #ifndef PLACEHOLDER_KERNELS
     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89)
+    if (arch == 89 || arch == 120)
     {
         gemm_dispatch_sm89(mat_a, mat_b, mat_d, scales_a, scales_b, shape_m, shape_n, shape_k, stream);
         return;
@@ -1883,7 +1882,7 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma
     }

     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89)
+    if (arch == 89 || arch == 120)
     {
         strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
             scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
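Both runtime entry points gate on tensorrt_llm::common::getSMVersion(), and SM 120 is simply routed through the existing SM 89 (Ada) dispatch path. For reference, a sketch of how such an SM version number is typically derived from device properties (an assumption; the actual getSMVersion() implementation is not part of this diff):

// sm_version.cu - presumed shape of a getSMVersion()-style helper:
// major * 10 + minor, e.g. compute capability 12.0 -> 120, 8.9 -> 89.
#include <cuda_runtime.h>

int getSMVersionSketch()
{
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    return prop.major * 10 + prop.minor;
}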

cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp

Lines changed: 1 addition & 0 deletions

@@ -209,6 +209,7 @@ extern torch::Tensor fp8_block_scaling_gemm(torch::Tensor const& mat1, torch::Te
     case 100: return fp8_block_scale_gemm_blackwell(mat1, mat2, mat1Scale, mat2Scale);
     case 90: return fp8_block_scaling_gemm_hopper(mat1, mat2, mat1Scale, mat2Scale);
     case 89: return fp8_block_scaling_gemm_ada(mat1, mat2, mat1Scale, mat2Scale);
+    case 120: return fp8_block_scaling_gemm_ada(mat1, mat2, mat1Scale, mat2Scale);
     default: TORCH_CHECK(false, "Unsupported SM version for FP8 block scaling GEMM");
     }
 }
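Note the mapping this switch establishes: SM 120 (Blackwell GeForce-class parts, compute capability 12.0) is dispatched to fp8_block_scaling_gemm_ada, the same path as SM 89, rather than to the SM 100 Blackwell path, which presumably depends on datacenter-only features. The CMake and kernel-side changes above are what make that reuse possible.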

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 1 deletion

@@ -648,7 +648,7 @@ def fp8_block_scaling_bmm_out(
     mat2_dequant: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     sm_version = get_sm_version()
-    if sm_version == 90 or sm_version == 89:
+    if sm_version == 90 or sm_version == 89 or sm_version == 120:
         mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
             mat1)
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py

Lines changed: 2 additions & 2 deletions

@@ -63,7 +63,7 @@ def test_fp8_block_scale_deep_gemm(dtype, m, k, n):


 @pytest.mark.skipif(
-    getSMVersion() != 100 and getSMVersion() != 89,
+    getSMVersion() != 100 and getSMVersion() != 89 and getSMVersion() != 120,
     reason="The test is for Blackwell and Ada only. Current SM is %d." %
     getSMVersion(),
 )
@@ -99,7 +99,7 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):


 @pytest.mark.skipif(
-    getSMVersion() != 90 and getSMVersion() != 89,
+    getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
     reason="The test is for Hopper and Ada only. Current SM is %d." %
     getSMVersion(),
 )
