Commit c76b428

[TRTLLM-9685] [feat] Add gather fc1 kernel by cuteDSL (#9618)
Signed-off-by: Zongfei Jing <[email protected]>
1 parent b8a5159 commit c76b428

File tree: 7 files changed (+3915 / -16 lines)

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1439,7 +1439,7 @@ repos:
       additional_dependencies:
         - tomli
       # add ignore words list
-      args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
+      args: ["-L", "Mor,ans,thirdparty,subtiles", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.9.4
     hooks:

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 514 additions & 2 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 3029 additions & 0 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py

Lines changed: 152 additions & 2 deletions
@@ -48,8 +48,8 @@

 import cutlass.cute as cute
 from cutlass.cutlass_dsl import Boolean, if_generate
-from cutlass.pipeline import (CooperativeGroup, PipelineAsync, PipelineOp,
-                              PipelineState)
+from cutlass.pipeline import (Agent, CooperativeGroup, PipelineAsync,
+                              PipelineOp, PipelineState, agent_sync)


 def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
@@ -374,3 +374,153 @@ def then_body():
             self.producer_acquire(state)

         if_generate(is_leader_cta, then_body)
+
+
+@dataclass(frozen=True)
+class PipelineCpAsyncUmma(PipelineAsync):
+    """
+    PipelineCpAsyncUmma is used for LDGSTS (CpAsync) producers and UMMA consumers.
+
+    This pipeline is specifically designed for scenarios where:
+    - Producers use LDGSTS instructions (cp.async) to load data from global to shared memory
+    - Consumers are UMMA warps that perform MMA operations using the loaded data
+
+    Key differences from PipelineAsyncUmma:
+    - Suitable for gather/permutation operations during load
+    - Used in this kernel for A and SFA matrices with token-based gather addressing
+    """
+
+    cta_group: cute.nvgpu.tcgen05.CtaGroup
+
+    @staticmethod
+    def _compute_leading_cta_rank(cta_v_size):
+        """
+        Computes the leading CTA rank.
+        """
+        cta_rank_in_cluster = cute.arch.make_warp_uniform(
+            cute.arch.block_idx_in_cluster())
+        return cta_rank_in_cluster // cta_v_size * cta_v_size
+
+    @staticmethod
+    def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
+        """
+        Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
+        """
+        bidx, bidy, _ = cute.arch.block_idx()
+        mma_coord_vmnk = (
+            bidx % cute.size(cta_layout_vmnk, mode=[0]),
+            bidx // cute.size(cta_layout_vmnk, mode=[0]),
+            bidy,
+            None,
+        )
+        return mma_coord_vmnk[0] == 0
+
+    @staticmethod
+    def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
+        """
+        Computes a mask for signaling arrivals to multicasting threadblocks.
+        """
+        cta_rank_in_cluster = cute.arch.make_warp_uniform(
+            cute.arch.block_idx_in_cluster())
+        cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(
+            cta_rank_in_cluster)
+        mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask(
+            cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0)
+        block_in_cluster_coord_vmnk_peer = (
+            cta_in_cluster_coord_vmnk[0] ^ 1,
+            *cta_in_cluster_coord_vmnk[1:],
+        )
+        mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
+            cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0)
+        return mask_self | mask_peer
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        defer_sync: bool = False,
+        enable_cp_async: bool = False,
+    ):
+        """Creates and initializes a new PipelineCpAsyncUmma instance.
+
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: int
+        :param producer_group: CooperativeGroup for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: CooperativeGroup for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer, optional
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout, optional
+        :param defer_sync: Whether to defer the sync
+        :type defer_sync: bool, optional
+        :param enable_cp_async: Whether to enable cp.async instructions
+        :type enable_cp_async: bool, optional
+        :raises ValueError: If barrier_storage is not a cute.Pointer instance
+        :return: A new PipelineCpAsyncUmma instance configured with the provided parameters
+        :rtype: PipelineCpAsyncUmma
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages,
+            consumer)
+
+        cta_v_size = cute.size(cta_layout_vmnk,
+                               mode=[0]) if cta_layout_vmnk is not None else 1
+        cta_group = (cute.nvgpu.tcgen05.CtaGroup.ONE if cta_layout_vmnk is None
+                     or cute.size(cta_layout_vmnk, mode=[0]) == 1 else
+                     cute.nvgpu.tcgen05.CtaGroup.TWO)
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
+            # No mcast mask if we're not using 2CTA tcgen05 MMA
+            producer_mask = None
+            consumer_mask = None
+        else:
+            # If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA
+            # We need to get the target cta_rank
+            producer_mask = PipelineCpAsyncUmma._compute_leading_cta_rank(
+                cta_v_size)
+            # consumer needs to get the mask to signal
+            consumer_mask = PipelineCpAsyncUmma._compute_peer_cta_mask(
+                cta_layout_vmnk)
+
+        if not defer_sync:
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+                agent_sync(Agent.ThreadBlock)
+            else:
+                agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
+
+        return PipelineCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            cta_group,
+        )
+
+    def consumer_release(self, state: PipelineState):
+        """
+        UMMA consumer release buffer empty, cta_group needs to be provided.
+        """
+        self.sync_object_empty.arrive(state.index, self.consumer_mask,
+                                      self.cta_group)
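
For orientation, below is a minimal usage sketch of the pipeline class added above, with the cp.async load warps as the producer and the UMMA warp as the consumer. It is not code from this commit: the `CooperativeGroup` arguments, barrier pointer, and pipeline-state variables are assumptions based on the `create()` signature shown in the diff; see the gather grouped-GEMM kernel file in this commit for the actual wiring.

```python
# Illustrative sketch only; names and CooperativeGroup sizes are assumptions.
from cutlass.pipeline import Agent, CooperativeGroup

ab_pipeline = PipelineCpAsyncUmma.create(
    num_stages=num_ab_stages,                    # shared-memory buffer stages (assumed variable)
    producer_group=CooperativeGroup(Agent.Thread, num_load_threads),  # cp.async gather warps
    consumer_group=CooperativeGroup(Agent.Thread),                    # UMMA warp
    barrier_storage=smem_mbar_ptr,               # cute.Pointer into shared memory (assumed)
    cta_layout_vmnk=cta_layout_vmnk,             # cluster layout, or None for 1-CTA kernels
    enable_cp_async=True,                        # producer arrives via cp.async completion
)

# Producer (load warps): reserve a stage, issue gathered cp.async copies, then commit.
ab_pipeline.producer_acquire(ab_producer_state)
# ... cp.async copies of gathered A / SFA tiles into this stage's shared-memory buffer ...
ab_pipeline.producer_commit(ab_producer_state)

# Consumer (UMMA warp): wait for the stage, run the tcgen05 MMA, then release it.
ab_pipeline.consumer_wait(ab_consumer_state)
# ... tcgen05 MMA consuming the staged tiles ...
ab_pipeline.consumer_release(ab_consumer_state)
```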

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 4 additions & 10 deletions
@@ -273,22 +273,16 @@ def run_moe_nvfp4(
             local_num_experts=self.expert_size_per_partition,
             tile_tokens_dim=tile_size,
         )
-        x, x_sf = torch.ops.trtllm.moe_permute(
-            input=x.view(torch.float4_e2m1fn_x2),
-            input_sf=x_sf,
-            tile_idx_to_mn_limit=tile_idx_to_mn_limit,
-            permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
-            num_non_exiting_tiles=num_non_exiting_tiles,
-            tile_tokens_dim=tile_size,
-            top_k=self.routing_method.experts_per_token,
-        )
-        x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
+
+        x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
             input_scale=x_sf.view(torch.uint8),
             weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
             alpha=self.quant_scales.fc1_global,
             tile_idx_to_group_idx=tile_idx_to_expert_idx,
+            tile_idx_to_mn_limit=tile_idx_to_mn_limit,
+            permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
             num_non_exiting_tiles=num_non_exiting_tiles,
             global_sf=self.fc2_input_scale,
             num_experts=self.num_slots,
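
In effect, this hunk removes the standalone `moe_permute` pass: the fused op now receives the permutation metadata (`tile_idx_to_mn_limit`, `permuted_idx_to_expanded_idx`) and gathers activation rows while loading, instead of materializing a permuted copy of `x` / `x_sf` first. For intuition only, here is a toy full-precision reference of the gather + grouped GEMM + SwiGLU pattern; the index encoding, scale-factor handling, and gate/up ordering are simplified assumptions, not the kernel's contract.

```python
import torch
import torch.nn.functional as F


def toy_gather_grouped_gemm_swiglu(x, w3_w1, permuted_to_source_row, permuted_to_expert):
    """Toy fp32 reference: gather rows, per-expert GEMM, SwiGLU epilogue.

    x:                      [num_tokens, hidden]
    w3_w1:                  [num_experts, 2 * inter_size, hidden] (gate and up stacked, assumed order)
    permuted_to_source_row: [num_permuted_rows] source row of x for each permuted slot
    permuted_to_expert:     [num_permuted_rows] expert id owning each permuted slot
    """
    gathered = x[permuted_to_source_row]  # gather on the fly instead of a separate permute pass
    out = x.new_empty(gathered.shape[0], w3_w1.shape[1] // 2)
    for e in range(w3_w1.shape[0]):  # grouped GEMM: one GEMM per expert's rows
        rows = (permuted_to_expert == e).nonzero(as_tuple=True)[0]
        if rows.numel() == 0:
            continue
        h = gathered[rows] @ w3_w1[e].t()        # [rows, 2 * inter_size]
        gate, up = h.chunk(2, dim=-1)
        out[rows] = F.silu(gate) * up            # SwiGLU, fused into the kernel epilogue
    return out
```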

tensorrt_llm/_torch/utils.py

Lines changed: 9 additions & 0 deletions
@@ -291,6 +291,15 @@ def fp4_scale_infer_shape(input_shapes: List[List[int]]):
     return scale_shape * 2


+def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
+    """Calculate the dimensions of the fp4 scale tensor.
+    """
+    out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0],
+                                                     sf_vec_size=16,
+                                                     is_swizzled_layout=False)
+    return scale_shape * 2
+
+
 _enable_piecewise_cuda_graph = True
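
The new helper mirrors the existing `fp4_scale_infer_shape`, but asks `get_fp4_shape` for the row-major (unswizzled) scale-factor layout. Helpers like this typically feed the fake/meta shape registration of a custom op so output scale tensors can be traced without running the kernel; the snippet below only shows a hedged call pattern with made-up dimensions, not code from this commit.

```python
import torch

from tensorrt_llm._torch.utils import fp4_unswizzled_scale_infer_shape

# Illustrative only: dimensions are made up; the real caller is the shape-inference
# hook registered for the new gather grouped-GEMM custom op.
num_tokens, hidden = 256, 7168
scale_shape = fp4_unswizzled_scale_infer_shape([[num_tokens, hidden]])
scale_meta = torch.empty(scale_shape, dtype=torch.uint8, device="meta")
```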
