Skip to content

Commit f277afd

Browse files
authored
perf: Enable 128x256 tile shapes for FP4 MOE CUTLASS backend (NVIDIA#5986)
Signed-off-by: Daniel Stokes <[email protected]>
1 parent c4ee535 commit f277afd

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,10 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
383383
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
384384
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
385385
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
386-
// candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
387-
// MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
386+
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
387+
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
388+
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
389+
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
388390
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
389391
MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
390392
candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,16 @@ using SafeBF16 = void;
342342
using EpilogueTileShapeSm100 = decltype(shape_div(TileShape{}, AtomThrShape{})); \
343343
using EpilogueTileShape = std::conditional_t<IsBlackwell, EpilogueTileShapeSm100, EpilogueTileShapeSm90>; \
344344
using EpilogueElementC = std::conditional_t<IsSM120, ElementCSafe, ElementC>; \
345+
using EpilogueTensorOp = std::conditional_t<IsBlackwell && IsBlockScaled, \
346+
cutlass::arch::OpClassBlockScaledTensorOp, cutlass::arch::OpClassTensorOp>; \
347+
using EpilogueSubTile \
348+
= std::conditional_t<Arch::kMinComputeCapability == 100 && IsFP4 && CTA_N_ == 256, /* SM100 Exactly */ \
349+
cute::Shape<cute::_128, cute::_64>, cutlass::epilogue::collective::EpilogueTileAuto>; \
345350
/* Epilogue For Default Finalize */ \
346351
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder</**/ \
347-
Arch, cutlass::arch::OpClassTensorOp, /**/ \
352+
Arch, EpilogueTensorOp, /**/ \
348353
EpilogueTileShape, ClusterShape, /**/ \
349-
cutlass::epilogue::collective::EpilogueTileAuto, /**/ \
354+
EpilogueSubTile, /**/ \
350355
ElementAccumulator, ElementAccumulator, /**/ \
351356
EpilogueElementC, LayoutC*, AlignmentC, /**/ \
352357
ElementD, LayoutD*, AlignmentD, /**/ \

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ constexpr bool are_tile_shapes_supported_sm100()
159159
// {
160160
// return false;
161161
// }
162-
if ((TileN != 64 && TileN != 128) || TileM != 128)
162+
if ((TileN != 64 && TileN != 128 && TileN != 256) || TileM != 128)
163163
{
164164
return false;
165165
}

cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def is_gemm_op_valid_sm100(op):
359359
# TODO 128x256x256 FP4 compiles but crashes
360360
# if tile_n % 64 != 0 or tile_n < 128:
361361
# return False
362-
if tile_n not in [64, 128] or tile_m != 128:
362+
if tile_n not in [64, 128, 256] or tile_m != 128:
363363
return False
364364

365365
# Shapes for fp8 small N shapes

0 commit comments

Comments
 (0)