
Commit 21a93fb

[TRTLLM-9992][perf] Enable PDL for CuteDSL kernels and overlap MoeOutputMemset (NVIDIA#10043)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 3f25db9 · commit 21a93fb

14 files changed: +259 −183 lines
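Note: the new .utils helpers that every file below imports (TRTLLM_ENABLE_PDL, griddepcontrol_wait, griddepcontrol_launch_dependents) are not part of this commit view. As a rough sketch of what such wrappers could look like, assuming they emit the PTX griddepcontrol instructions through the same llvm.inline_asm pattern used by the atomic-add helpers later in this diff, and assuming the flag is read from an environment variable of the same name:

    # Hypothetical sketch; the real definitions live in the .utils module, which is not shown here.
    import os

    from cutlass._mlir.dialects import llvm
    from cutlass.cutlass_dsl import dsl_user_op

    # Assumption: PDL is gated by an environment variable; the default used here is a guess.
    TRTLLM_ENABLE_PDL = os.environ.get("TRTLLM_ENABLE_PDL", "1") == "1"


    @dsl_user_op
    def griddepcontrol_wait(*, loc=None, ip=None):
        # Block until the preceding grid (the PDL parent) has flushed its global writes.
        llvm.inline_asm(None, [], "griddepcontrol.wait;", "",
                        has_side_effects=True, loc=loc, ip=ip)


    @dsl_user_op
    def griddepcontrol_launch_dependents(*, loc=None, ip=None):
        # Allow grids that depend on this one to start launching early.
        llvm.inline_asm(None, [], "griddepcontrol.launch_dependents;", "",
                        has_side_effects=True, loc=loc, ip=ip)

The placement is the same in every kernel touched here: use_pdl=TRTLLM_ENABLE_PDL on the launch, griddepcontrol_wait() right after the CTA-wide entry barrier, and griddepcontrol_launch_dependents() once the epilogue pipeline has drained.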

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 0 additions & 26 deletions
@@ -2178,32 +2178,6 @@ def _(
             device=input_scale.device)
     return output, output_scale

-class FusedMoEInputsHelper:
-
-    def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
-                 local_expert_offset: int):
-        self.num_experts = num_experts
-        self.top_k = top_k
-        self.num_local_experts = num_local_experts
-        self.local_expert_offset = local_expert_offset
-
-    def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
-        return input_shapes[0][0]
-
-    def inputs_pre_hook(self,
-                        inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        x, x_sf, token_selected_experts, token_final_scales, *others = inputs
-        num_tokens = token_selected_experts.size(0)
-        new_token_final_scales, new_token_selected_experts = torch.randn(
-            num_tokens,
-            self.num_experts,
-            device=token_selected_experts.device).topk(self.top_k, dim=-1)
-        new_token_selected_experts = new_token_selected_experts.to(
-            token_selected_experts.dtype)
-        new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
-            token_final_scales.dtype)
-        return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
-
 class Sm100BlockScaledFusedMoERunner(TunableRunner):
     tuning_config_cache = dict()
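For reference, the removed inputs_pre_hook synthesized random but self-consistent routing tensors for autotuning. The same pattern stands alone as below (shapes are illustrative); drawing top-k over per-expert logits keeps the expert indices valid and lets the final scales be softmax-normalized over exactly the selected experts:

    import torch

    num_tokens, num_experts, top_k = 128, 64, 8

    # topk over random logits returns (scores, expert indices) that agree with each other.
    scores, selected_experts = torch.randn(num_tokens, num_experts).topk(top_k, dim=-1)
    final_scales = scores.softmax(dim=-1)

    assert selected_experts.shape == (num_tokens, top_k)
    assert int(selected_experts.max()) < num_experts
    assert torch.allclose(final_scales.sum(dim=-1), torch.ones(num_tokens))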

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 14 additions & 35 deletions
@@ -35,44 +35,18 @@
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op

 from .custom_pipeline import PipelineCpAsyncUmma
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
-    """
-    return a * sigmoid_f32(a, fastmath=fastmath)
-
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    fmin,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    silu_f32,
+)

 """
 High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion

@@ -819,6 +793,7 @@ class SharedStorage:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -1148,6 +1123,8 @@ def kernel(
         else:
             self.cta_sync_barrier.arrive_and_wait()

+        griddepcontrol_wait()
+
         #
         # Specialized Schedule warp
         #

@@ -2282,6 +2259,8 @@ def kernel(
             #
             c_pipeline.producer_tail()

+        griddepcontrol_launch_dependents()
+
     def epilog_tmem_copy_and_partition(
         self,
         tidx: cutlass.Int32,
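The silu_f32 and sigmoid_f32 helpers removed above now come from .utils; they implement the activation applied by the fused SwiGLU epilogue. A plain PyTorch reference of the same math is handy when checking the kernel numerically. The gate/up operand split and any clamping via fmin inside the fused kernel are not visible in this hunk, so the tensor names below are illustrative:

    import torch

    def swiglu_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # silu(x) = x * sigmoid(x); the fused kernel applies this to the GEMM output
        # before the result is written back.
        return torch.nn.functional.silu(gate) * up

    x = torch.randn(4, 8, dtype=torch.float32)
    gate, up = x.chunk(2, dim=-1)
    manual = gate * torch.sigmoid(gate) * up  # same as silu_f32(gate) * up, up to rounding
    assert torch.allclose(swiglu_reference(gate, up), manual, atol=1e-6)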

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py

Lines changed: 11 additions & 1 deletion
@@ -52,7 +52,12 @@
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu import cpasync, tcgen05

-from .utils import is_power_of_2
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+)


 class Sm100BlockScaledContiguousGroupedGemmKernel:

@@ -597,6 +602,7 @@ class SharedStorage:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -933,6 +939,8 @@ def kernel(
         else:
             self.cta_sync_barrier.arrive_and_wait()

+        griddepcontrol_wait()
+
         #
         # Specialized Schedule warp
         #

@@ -1597,6 +1605,8 @@ def kernel(
             #
             c_pipeline.producer_tail()

+        griddepcontrol_launch_dependents()
+
     def epilog_tmem_copy_and_partition(
         self,
         tidx: cutlass.Int32,
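The three changes above repeat in every kernel in this commit: the consolidated .utils import, use_pdl=TRTLLM_ENABLE_PDL at launch, and the two griddepcontrol calls inside kernel. Assuming the flag is read from an environment variable of the same name (the .utils source is not shown here), comparing runs with and without PDL would only require toggling it before the modules are imported:

    import os

    # Hypothetical toggle; set before importing tensorrt_llm so the module-level
    # constant in .utils picks it up. "1" enables programmatic dependent launch.
    os.environ["TRTLLM_ENABLE_PDL"] = "0"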

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Lines changed: 16 additions & 68 deletions
@@ -35,11 +35,17 @@
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import llvm
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cutlass_dsl import Int32, T, dsl_user_op

-from .utils import is_power_of_2
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    atomic_add_func,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    vectorized_atomic_add_bf16x8,
+    vectorized_atomic_add_fp32x2,
+)

 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for

@@ -259,8 +265,8 @@ def hooked_PersistentTileSchedulerParams_init(


 def hooked_get_cluster_work_idx_with_fastdivmod(
-    self, current_work_linear_idx: Int32, *, loc=None, ip=None
-) -> Tuple[Int32, Int32, Int32]:
+    self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None
+) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
     work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)

     if self.params._raster_along_m:

@@ -287,69 +293,6 @@ def hooked_get_cluster_work_idx_with_fastdivmod(
     )


-# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
-@dsl_user_op
-def vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
-        ],
-        "red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
-        "l,r,r,r,r",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            rOut_epi_packed[0].ir_value(),
-            rOut_epi_packed[1].ir_value(),
-        ],
-        "red.global.v2.f32.add [$0], {$1, $2};",
-        "l,f,f",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                rOut_epi_packed.ir_value(),
-            ],
-            "red.global.add.f32 [$0], $1;",
-            "l,f",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-    elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
-            ],
-            "red.add.noftz.bf16 [$0], $1;",
-            "l,h",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-
-
 class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
     and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.

@@ -931,6 +874,7 @@ class SharedStorage:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -1286,6 +1230,8 @@ def kernel(
         else:
             self.cta_sync_barrier.arrive_and_wait()

+        griddepcontrol_wait()
+
         #
         # Specialized Schedule warp
         #

@@ -1940,6 +1886,8 @@ def kernel(
             self.epilog_sync_barrier.arrive_and_wait()
             tmem.free(tmem_ptr)

+        griddepcontrol_launch_dependents()
+
     def epilog_tmem_copy_and_partition(
         self,
         tidx: cutlass.Int32,
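The atomic-add helpers moved to .utils back the finalize step of this kernel's epilogue: each expert's output row is reduced into its destination row in global memory with red.global adds rather than by a separate finalize kernel. An unfused PyTorch reference of that reduction, with illustrative tensor names and the (assumed) application of the routing weights during accumulation:

    import torch

    def finalize_reference(expert_out: torch.Tensor,    # [num_expanded_rows, hidden]
                           final_scales: torch.Tensor,  # [num_expanded_rows]
                           dest_token: torch.Tensor,    # [num_expanded_rows], int64
                           num_tokens: int) -> torch.Tensor:
        out = torch.zeros(num_tokens, expert_out.size(-1), dtype=torch.float32,
                          device=expert_out.device)
        # index_add_ plays the role of the kernel's atomic red.global adds.
        out.index_add_(0, dest_token, expert_out.float() * final_scales.float()[:, None])
        return out

The vectorized variants simply issue the same reduction at a wider granularity: red.global.v4.bf16x2.add.noftz covers eight bf16 values per instruction and red.global.v2.f32.add covers two fp32 values.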

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py

Lines changed: 14 additions & 35 deletions
@@ -35,43 +35,17 @@
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op
-
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
-    """
-    return a * sigmoid_f32(a, fastmath=fastmath)

+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    fmin,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    silu_f32,
+)

 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for

@@ -749,6 +723,7 @@ class SharedStorage:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -1087,6 +1062,8 @@ def kernel(
         else:
             self.cta_sync_barrier.arrive_and_wait()

+        griddepcontrol_wait()
+
         #
         # Specialized Schedule warp
         #

@@ -1949,6 +1926,8 @@ def kernel(
             #
             c_pipeline.producer_tail()

+        griddepcontrol_launch_dependents()
+
     def epilog_tmem_copy_and_partition(
         self,
         tidx: cutlass.Int32,

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py

Lines changed: 7 additions & 1 deletion
@@ -55,7 +55,8 @@
 from cutlass.cute.nvgpu import cpasync, tcgen05

 from .custom_pipeline import PipelineTmaUmma, PipelineUmmaAsync
-from .utils import is_power_of_2
+from .utils import (TRTLLM_ENABLE_PDL, griddepcontrol_launch_dependents,
+                    griddepcontrol_wait, is_power_of_2)


 class Sm100BlockScaledPersistentDenseGemmKernel:

@@ -578,6 +579,7 @@ class SharedStorage:
             smem=self.shared_storage.size_in_bytes(),
             min_blocks_per_mp=1,
             stream=stream,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -869,6 +871,8 @@ def kernel(
         cute.arch.barrier(barrier_id=self.cta_sync_bar_id,
                           number_of_threads=self.threads_per_cta)

+        griddepcontrol_wait()
+
         #
         # Specialized TMA load warp
         #

@@ -1473,6 +1477,8 @@ def kernel(
             #
             c_pipeline.producer_tail()

+        griddepcontrol_launch_dependents()
+
     def mainloop_s2t_copy_and_partition(
         self,
         sSF: cute.Tensor,
