
Commit 46e4af5

sherry-1001 and syuoni authored
[TRTLLM-9831][perf] Enable 2CTA with autotune for CuteDSL MoE and Grouped GEMM optimizations (#10201)
Signed-off-by: zhichen jiang <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Co-authored-by: Enwei Zhu <[email protected]>
1 parent fe12fae · commit 46e4af5

15 files changed: +1308 / -502 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 22 additions & 9 deletions
@@ -639,10 +639,13 @@ class TacticsCapture:
         - Current replay state (which config and call index)
         """
 
+    runner_tactic_comb_checkers: List[Callable] = []
+
     def __init__(self, autotuner):
         # State for captured contexts
         self._captured_contexts: List[Dict[str, Any]] = []
-        self._configurations = None
+        self._context_tactics_lists: Optional[List[List[Tuple[int,
+                                                              Any]]]] = None
         # State for replay mode
         self._replay_runner_tactic_list: Optional[List[Tuple[int,
                                                              int]]] = None
@@ -654,10 +657,13 @@ def __iter__(self):
         For single context: yields (runner, tactic)
         For multiple contexts: yields ((runner_ctx0, tactic_ctx0), (runner_ctx1, tactic_ctx1), ...)
         """
-        if self._configurations is None:
-            self._configurations = self._generate_configurations()
+        if self._context_tactics_lists is None:
+            self._context_tactics_lists = self._generate_context_tactics_lists(
+            )
 
-        for config in self._configurations:
+        # Generate cartesian product from context and tactics where all_configrations[i][ctx] = (runner, tactic)
+        # Such that each element in all_configrations is a replay of multiple contexts of all possible replays
+        for config in itertools.product(*self._context_tactics_lists):
             # config is a tuple of (runner_idx, tactic) for each context
             # Convert to (runner, tactic) format for user
             runner_tactic_pairs = []
@@ -666,9 +672,14 @@ def __iter__(self):
                 runner = runners[runner_idx]
                 runner_tactic_pairs.append((runner, tactic))
 
+            if not all(
+                    checker(runner_tactic_pairs) for checker in
+                    self.__class__.runner_tactic_comb_checkers):
+                continue
+
             yield tuple(runner_tactic_pairs)
 
-    def _generate_configurations(self):
+    def _generate_context_tactics_lists(self):
         """Generate all valid tactic combinations."""
         if not self._captured_contexts:
             raise RuntimeError(
@@ -694,15 +705,17 @@ def _generate_configurations(self):
                 tactics_lists.append((runner_idx, tactic))
             context_tactics_lists.append(tactics_lists)
 
-        # Generate cartesian product from context and tactics where all_configrations[i][ctx] = (runner, tactic)
-        # Such that each element in all_configrations is a replay of multiple contexts of all possible replays
-        all_configurations = list(itertools.product(*context_tactics_lists))
-        return all_configurations
+        return context_tactics_lists
 
     def is_replaying(self) -> bool:
         """Check if this TacticsCapture is currently in replay mode."""
         return self._replay_runner_tactic_list is not None
 
+    @classmethod
+    def register_runner_tactic_comb_checker(cls, checker: Callable):
+        cls.runner_tactic_comb_checkers.append(checker)
+        return checker
+
     def choose_one(
         self,
         custom_op: str,
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 51 additions & 266 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 4 additions & 31 deletions
@@ -176,8 +176,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
 
     :param sf_vec_size: Scalefactor vector size (16 for NVF4, 32 for MXF4/MXF8).
     :type sf_vec_size: int
-    :param acc_dtype: Data type of the accumulator (e.g., cutlass.Float32).
-    :type acc_dtype: Type[cutlass.Numeric]
     :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N).
         Note: use_2cta_instrs is automatically inferred from mma_tiler_mn[0]
         (True when M=256, False when M=128).
@@ -217,7 +215,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
     >>> # (True when M=256, False when M=128)
     >>> gemm = BlockScaledContiguousGatherGroupedGemmKernel(
     ...     sf_vec_size=16,
-    ...     acc_dtype=cutlass.Float32,
     ...     mma_tiler_mn=(256, 128),  # use_2cta_instrs=True since M=256
     ...     cluster_shape_mn=(2, 1),
     ...     vectorized_f32=True,
@@ -243,7 +240,6 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
     def __init__(
         self,
         sf_vec_size: int,
-        acc_dtype: Type[cutlass.Numeric],
         mma_tiler_mn: Tuple[int, int],
         cluster_shape_mn: Tuple[int, int],
         vectorized_f32: bool,
@@ -274,8 +270,6 @@ def __init__(
 
         :param sf_vec_size: Vector size for scale factors (16 for NVF4, 32 for MXF4/MXF8).
         :type sf_vec_size: int
-        :param acc_dtype: Data type of the accumulator.
-        :type acc_dtype: type[cutlass.Numeric]
         :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction.
            use_2cta_instrs is automatically set based on M (True if M=256, False if M=128).
         :type mma_tiler_mn: Tuple[int, int]
@@ -289,7 +283,7 @@ def __init__(
 
         self.sf_vec_size = sf_vec_size
         self.topk = topk
-        self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
+        self.acc_dtype = cutlass.Float32
         self.use_2cta_instrs = mma_tiler_mn[0] == 256
         self.cluster_shape_mn = cluster_shape_mn
         # K dimension is deferred in _setup_attributes
@@ -2620,7 +2614,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
         ab_dtype: Type[cutlass.Numeric],
         sf_dtype: Type[cutlass.Numeric],
         sf_vec_size: int,
-        acc_dtype: Type[cutlass.Numeric],
         c_dtype: Type[cutlass.Numeric],
     ) -> bool:
         """
@@ -2632,8 +2625,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
         :type sf_dtype: Type[cutlass.Numeric]
         :param sf_vec_size: The vector size of the scale factor
         :type sf_vec_size: int
-        :param acc_dtype: The data type of the accumulator
-        :type acc_dtype: Type[cutlass.Numeric]
         :param c_dtype: The data type of the output tensor
         :type c_dtype: Type[cutlass.Numeric]
 
@@ -2662,8 +2653,6 @@ def is_valid_dtypes_and_scale_factor_vec_size(
         if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16:
             is_valid = False
 
-        if acc_dtype not in {cutlass.Float32}:
-            is_valid = False
         # Check valid c_dtype
         if c_dtype not in {
             cutlass.Float32,
@@ -2717,7 +2706,6 @@ def is_valid_mma_tiler_and_cluster_shape(
         use_2cta_instrs: bool,
         mma_tiler_mn: Tuple[int, int],
         cluster_shape_mn: Tuple[int, int],
-        m_aligned: cutlass.Int64,
     ) -> bool:
         """
         Check if the mma tiler and cluster shape are valid
@@ -2728,8 +2716,6 @@ def is_valid_mma_tiler_and_cluster_shape(
         :type mma_tiler_mn: Tuple[int, int]
         :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
         :type cluster_shape_mn: Tuple[int, int]
-        :param m_aligned: The alignment requirement for group M dimension (default: 128)
-        :type m_aligned: cutlass.Int64
 
         :return: True if the mma tiler and cluster shape are valid, False otherwise
         :rtype: bool
@@ -2771,13 +2757,6 @@ def is_valid_mma_tiler_and_cluster_shape(
         if cluster_tiler_m not in [64, 128, 256]:
             is_valid = False
 
-        # Check if m_aligned is a multiple of cluster_tiler_m
-        # This ensures that each group's M dimension (which is a multiple of m_aligned)
-        # won't be split across tiles, preventing a single tile from loading data
-        # from multiple groups (which would access wrong B matrix data)
-        if m_aligned % mma_tiler_mn[0] != 0:
-            is_valid = False
-
         return is_valid
 
     @staticmethod
@@ -2838,7 +2817,6 @@ def can_implement(
         ab_dtype: Type[cutlass.Numeric],
         sf_dtype: Type[cutlass.Numeric],
         sf_vec_size: int,
-        acc_dtype: Type[cutlass.Numeric],
         c_dtype: Type[cutlass.Numeric],
         mma_tiler_mn: Tuple[int, int],
         cluster_shape_mn: Tuple[int, int],
@@ -2849,7 +2827,6 @@ def can_implement(
         a_major: str,
         b_major: str,
         c_major: str,
-        m_aligned: cutlass.Int64,
     ) -> bool:
         """
         Check if the gemm can be implemented
@@ -2860,8 +2837,6 @@ def can_implement(
         :type sf_dtype: Type[cutlass.Numeric]
         :param sf_vec_size: The vector size of the scale factor
         :type sf_vec_size: int
-        :param acc_dtype: The data type of the accumulator
-        :type acc_dtype: Type[cutlass.Numeric]
         :param c_dtype: The data type of the output tensor
         :type c_dtype: Type[cutlass.Numeric]
         :param use_2cta_instrs: Whether to use 2 CTA groups
@@ -2884,16 +2859,14 @@ def can_implement(
         :type b_major: str
         :param c_major: The major axis of the C tensor
         :type c_major: str
-        :param m_aligned: The alignment requirement for group M dimension (default: 128)
-        :type m_aligned: cutlass.Int64
 
         :return: True if the gemm can be implemented, False otherwise
         :rtype: bool
         """
         can_implement = True
         # Skip unsupported types
         if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
-            ab_dtype, sf_dtype, sf_vec_size, acc_dtype, c_dtype
+            ab_dtype, sf_dtype, sf_vec_size, c_dtype
         ):
             can_implement = False
 
@@ -2903,10 +2876,10 @@ def can_implement(
         ):
             can_implement = False
 
-        use_2cta_instrs = mma_tiler_mn[0] == 256
         # Skip invalid mma tile shape and cluster shape
+        use_2cta_instrs = mma_tiler_mn[0] == 256
         if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
-            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
+            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn
         ):
             can_implement = False
         # Skip illegal problem shape for load/store alignment
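
For reference, the 2CTA behavior is now inferred rather than configured: use_2cta_instrs is derived from the MMA tile M dimension, and the accumulator dtype is always cutlass.Float32. The helper below is a standalone sketch of that inference rule written for illustration; it is not part of the kernel.

from typing import Tuple


def infer_use_2cta_instrs(mma_tiler_mn: Tuple[int, int]) -> bool:
    # Mirrors the kernel's rule: 2CTA instructions when the MMA tile M is 256,
    # single-CTA when it is 128.
    return mma_tiler_mn[0] == 256


assert infer_use_2cta_instrs((256, 128)) is True   # 2CTA path
assert infer_use_2cta_instrs((128, 128)) is False  # single-CTA path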
