Skip to content

Commit 2082a0c

Browse files
committed
fix alltoall
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
1 parent 7fabac3 commit 2082a0c

File tree

5 files changed

+76
-33
lines changed

5 files changed

+76
-33
lines changed

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def inplace_info():
8080
torch.ops.trtllm.moe_output_memset_inplace.default: {
8181
1: "input"
8282
},
83-
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell.default:
83+
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default:
8484
{
8585
6: "output"
8686
}

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 56 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1182,16 +1182,16 @@ def forward(self, inputs: List[torch.Tensor],
11821182
return c
11831183

11841184
@torch.library.custom_op(
1185-
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
1185+
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
11861186
mutates_args=("output", ),
11871187
device_types="cuda")
1188-
def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
1188+
def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
11891189
input: torch.Tensor,
11901190
weight: torch.Tensor,
11911191
input_scale: torch.Tensor,
11921192
weight_scale: torch.Tensor,
11931193
alpha: torch.Tensor,
1194-
output: Optional[torch.Tensor],
1194+
output: torch.Tensor,
11951195
tile_idx_to_group_idx: torch.Tensor,
11961196
tile_idx_to_mn_limit: torch.Tensor,
11971197
permuted_idx_to_expanded_idx: torch.Tensor,
@@ -1204,21 +1204,13 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
12041204
tile_size: int,
12051205
output_dtype: torch.dtype,
12061206
scaling_vector_size: int = 16,
1207-
) -> torch.Tensor:
1207+
) -> None:
12081208
tuner = AutoTuner.get()
12091209

12101210
runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
12111211
num_experts, top_k, num_local_experts, local_expert_offset,
12121212
tile_size, output_dtype, scaling_vector_size)
12131213

1214-
if output is None:
1215-
num_tokens = token_final_scales.size(0)
1216-
n = weight.size(1)
1217-
output = torch.zeros(num_tokens,
1218-
n,
1219-
dtype=output_dtype,
1220-
device=input.device)
1221-
12221214
inputs = [
12231215
input, weight, input_scale, weight_scale, alpha, output,
12241216
tile_idx_to_group_idx, tile_idx_to_mn_limit,
@@ -1227,12 +1219,62 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
12271219
]
12281220

12291221
_, best_tactic = tuner.choose_one(
1230-
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
1222+
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
12311223
[runner],
12321224
runner.get_tuning_config(),
12331225
inputs,
12341226
)
1235-
output = runner(inputs, tactic=best_tactic)
1227+
runner(inputs, tactic=best_tactic)
1228+
1229+
@torch.library.custom_op(
1230+
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
1231+
mutates_args=(),
1232+
device_types="cuda")
1233+
def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
1234+
input: torch.Tensor,
1235+
weight: torch.Tensor,
1236+
input_scale: torch.Tensor,
1237+
weight_scale: torch.Tensor,
1238+
alpha: torch.Tensor,
1239+
tile_idx_to_group_idx: torch.Tensor,
1240+
tile_idx_to_mn_limit: torch.Tensor,
1241+
permuted_idx_to_expanded_idx: torch.Tensor,
1242+
num_non_exiting_tiles: torch.Tensor,
1243+
token_final_scales: torch.Tensor,
1244+
num_experts: int,
1245+
top_k: int,
1246+
num_local_experts: int,
1247+
local_expert_offset: int,
1248+
tile_size: int,
1249+
output_dtype: torch.dtype,
1250+
scaling_vector_size: int = 16,
1251+
) -> torch.Tensor:
1252+
num_tokens = token_final_scales.size(0)
1253+
n = weight.size(1)
1254+
output = torch.zeros(num_tokens,
1255+
n,
1256+
dtype=output_dtype,
1257+
device=input.device)
1258+
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
1259+
input=input,
1260+
weight=weight,
1261+
input_scale=input_scale,
1262+
weight_scale=weight_scale,
1263+
alpha=alpha,
1264+
output=output,
1265+
tile_idx_to_group_idx=tile_idx_to_group_idx,
1266+
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
1267+
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
1268+
num_non_exiting_tiles=num_non_exiting_tiles,
1269+
token_final_scales=token_final_scales,
1270+
num_experts=num_experts,
1271+
top_k=top_k,
1272+
num_local_experts=num_local_experts,
1273+
local_expert_offset=local_expert_offset,
1274+
tile_size=tile_size,
1275+
output_dtype=output_dtype,
1276+
scaling_vector_size=scaling_vector_size,
1277+
)
12361278
return output
12371279

12381280
@torch.library.register_fake(
@@ -1243,7 +1285,6 @@ def _(
12431285
input_scale: torch.Tensor,
12441286
weight_scale: torch.Tensor,
12451287
alpha: torch.Tensor,
1246-
output: Optional[torch.Tensor],
12471288
tile_idx_to_group_idx: torch.Tensor,
12481289
tile_idx_to_mn_limit: torch.Tensor,
12491290
permuted_idx_to_expanded_idx: torch.Tensor,

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -908,27 +908,27 @@ def _get_backend_kwargs(
908908
kwargs = {}
909909

910910
# Common parameters for Cutlass and DeepGemm
911-
if isinstance(self.backend, (CutlassFusedMoE, DeepGemmFusedMoE)):
911+
if self.backend.__class__ in (CutlassFusedMoE, DeepGemmFusedMoE, CuteDslFusedMoE):
912912
pass
913913

914914
# Cutlass-specific parameters
915-
if isinstance(self.backend, CutlassFusedMoE):
915+
if self.backend.__class__ == CutlassFusedMoE:
916916
pass
917917

918918
# CuteDSL-specific parameters
919-
elif isinstance(self.backend, CuteDslFusedMoE):
919+
elif self.backend.__class__ == CuteDslFusedMoE:
920920
kwargs["enable_alltoall"] = self.enable_alltoall
921921

922922
# WideEP-specific parameters
923-
elif isinstance(self.backend, WideEPMoE):
923+
elif self.backend.__class__ == WideEPMoE:
924924
pass
925925

926926
# DeepGemm-specific parameters
927-
elif isinstance(self.backend, DeepGemmFusedMoE):
927+
elif self.backend.__class__ == DeepGemmFusedMoE:
928928
pass
929929

930930
# TRTLLMGen-specific parameters
931-
elif isinstance(self.backend, TRTLLMGenFusedMoE):
931+
elif self.backend.__class__ == TRTLLMGenFusedMoE:
932932
# Determine router_logits based on whether routing has been done
933933
# If backend doesn't support load balancer, routing is done before communication
934934
# In that case, router_logits should be None (routing already done)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,6 @@ def run_moe_nvfp4(
259259
enable_alltoall: bool = False,
260260
) -> torch.Tensor:
261261
assert self.has_nvfp4
262-
output_shape = x.size()
263262
output_dtype = torch.bfloat16
264263
tile_size = 128
265264

@@ -297,21 +296,22 @@ def run_moe_nvfp4(
297296
tile_size=tile_size,
298297
)
299298
if self.use_fused_finalize:
300-
output = None
299+
output = torch.empty((token_final_scales.size(0), self.hidden_size),
300+
dtype=output_dtype,
301+
device=x.device)
301302
if enable_alltoall:
302-
output = torch.empty(output_shape,
303-
dtype=output_dtype,
304-
device=x.device)
305303
torch.ops.trtllm.moe_output_memset_inplace(
306-
output=output,
304+
input=output,
307305
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
308306
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
309307
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
310308
num_non_exiting_tiles=num_non_exiting_tiles,
311-
tile_size=tile_size,
309+
tile_tokens_dim=tile_size,
312310
top_k=self.routing_method.experts_per_token,
313311
)
314-
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
312+
else:
313+
output.fill_(0)
314+
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
315315
input=x.view(torch.float4_e2m1fn_x2),
316316
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
317317
input_scale=x_sf.view(torch.uint8),
@@ -331,6 +331,7 @@ def run_moe_nvfp4(
331331
tile_size=tile_size,
332332
output_dtype=output_dtype,
333333
)
334+
x = output
334335
else:
335336
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
336337
input=x.view(torch.float4_e2m1fn_x2),
@@ -462,13 +463,15 @@ def run_moe(
462463
x=x,
463464
token_selected_experts=token_selected_experts,
464465
token_final_scales=token_final_scales,
465-
x_sf=x_sf)
466+
x_sf=x_sf,
467+
enable_alltoall=enable_alltoall)
466468
elif self.has_deepseek_fp8_block_scales:
467469
return self.run_moe_fp8_block_scales(
468470
x=x,
469471
token_selected_experts=token_selected_experts,
470472
token_final_scales=token_final_scales,
471-
x_sf=x_sf)
473+
x_sf=x_sf,
474+
enable_alltoall=enable_alltoall)
472475
else:
473476
raise ValueError(
474477
f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}."

tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -535,7 +535,6 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
535535
a_sf,
536536
b_sf,
537537
alpha,
538-
None, # output
539538
tile_idx_to_group_idx,
540539
tile_idx_to_mn_limit,
541540
permuted_idx_to_expanded_idx,

0 commit comments

Comments
 (0)