Commit 6fe89ea

[TRTLLM-9819][perf] Reuse alltoall workspace for CuteDSL MoE output (#9840)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 0b279f4 commit 6fe89ea

7 files changed: +98 additions, −54 deletions

cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp

Lines changed: 24 additions & 6 deletions
@@ -205,24 +205,26 @@ std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Ten
 
 // Unpermute
 
-torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
-    torch::Tensor const& topk_scales)
+void moe_unpermute_inplace(torch::Tensor const& permuted_input, torch::Tensor const& output,
+    torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& topk_scales)
 {
     TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
     int64_t const max_num_permuted_tokens = permuted_input.size(0);
     int64_t const hidden_size = permuted_input.size(1);
+    TORCH_CHECK(output.dim() == 2, "output must be 2D.");
+    int64_t const num_tokens = output.size(0);
+    TORCH_CHECK(output.size(1) == hidden_size, "output.size(1) must be hidden_size.");
+
     TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
-    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+    TORCH_CHECK(
+        expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
     int64_t const top_k = expanded_idx_to_permuted_idx.size(1);
     TORCH_CHECK(topk_scales.dim() == 2, "topk_scales must be 2D.");
     TORCH_CHECK(topk_scales.size(0) == num_tokens, "topk_scales.size(0) must be num_tokens.");
     TORCH_CHECK(topk_scales.size(1) == top_k, "topk_scales.size(1) must be top_k.");
-
     TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
         "max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
 
-    auto output
-        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
     auto const& stream = at::cuda::getCurrentCUDAStream(permuted_input.get_device());
 
 #define DISPATCH_MOE_UNPERMUTE(InputType, TopKScaleType) \
@@ -253,7 +255,19 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
 }
 
 #undef DISPATCH_MOE_UNPERMUTE
+}
 
+torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
+    torch::Tensor const& topk_scales)
+{
+    TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
+    int64_t const hidden_size = permuted_input.size(1);
+    TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
+    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+
+    auto output
+        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
+    moe_unpermute_inplace(permuted_input, output, expanded_idx_to_permuted_idx, topk_scales);
     return output;
 }
 
@@ -489,6 +503,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
         "Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
+    m.def(
+        "moe_unpermute_inplace(Tensor permuted_input, Tensor(a!) output, Tensor expanded_idx_to_permuted_idx, Tensor "
+        "topk_scales) -> ()");
     m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
     m.def(
         "moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
@@ -510,6 +527,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
     m.impl("moe_topk_sort", &tensorrt_llm::torch_ext::moe_topk_sort);
     m.impl("moe_sort", &tensorrt_llm::torch_ext::moe_sort);
     m.impl("moe_permute", &tensorrt_llm::torch_ext::moe_permute);
+    m.impl("moe_unpermute_inplace", &tensorrt_llm::torch_ext::moe_unpermute_inplace);
     m.impl("moe_unpermute", &tensorrt_llm::torch_ext::moe_unpermute);
     m.impl("moe_output_memset_inplace", &tensorrt_llm::torch_ext::moe_output_memset_inplace);
     m.impl("moe_swiglu", &tensorrt_llm::torch_ext::moe_swiglu);

tensorrt_llm/_torch/compilation/utils.py

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ def inplace_info():
         torch.ops.trtllm.logits_bitmask.default: {
             1: "logits"
         },
+        torch.ops.trtllm.moe_unpermute_inplace.default: {
+            2: "output"
+        },
         torch.ops.trtllm.moe_output_memset_inplace.default: {
             1: "input"
         },
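
The inplace_info() table tells the compilation passes which argument of a custom op is mutated in place. Judging by the existing entries (moe_output_memset_inplace maps its first argument, input, to key 1), the keys appear to be 1-based argument positions, so the new entry marks output, the second argument of moe_unpermute_inplace, as the mutated tensor. A small, pure-Python sketch of that convention follows; the helper and the string keys are illustrative, not code from this repository (the real table keys are torch.ops.trtllm.<op>.default overloads).

# Pure-Python sketch so this runs without the trtllm extension.
inplace_table = {
    "trtllm::moe_unpermute_inplace": {2: "output"},
    "trtllm::moe_output_memset_inplace": {1: "input"},
}

def mutated_arguments(op_name, table):
    # Assumes the keys are 1-based argument positions, as the existing
    # entries suggest; returns (argument name, 0-based position) pairs.
    return [(name, key - 1) for key, name in table.get(op_name, {}).items()]

# moe_unpermute_inplace(permuted_input, output, ...) mutates `output`,
# its second argument, hence the key 2.
print(mutated_arguments("trtllm::moe_unpermute_inplace", inplace_table))
# -> [('output', 1)]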

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 12 additions & 14 deletions
@@ -991,23 +991,17 @@ def _get_nvlink_onesided_moe_output(
         if not isinstance(self.comm, NVLinkOneSided):
             return None
 
+        if not self.backend.supports_moe_output_in_alltoall_workspace():
+            # Ensure payload_in_workspace is False if backend doesn't support it
+            self.comm.payload_in_workspace = False
+            return None
+
         # Determine workspace dtype and whether backend supports workspace output
         workspace_dtype = output_dtype
-        backend_supports_workspace = False
-
         if isinstance(self.backend, TRTLLMGenFusedMoE):
             # TRTLLMGen specific configuration
             self.comm.invalid_token_expert_id = -1
             workspace_dtype = torch.bfloat16
-            backend_supports_workspace = self.backend.has_w4a8_mxfp4_mxfp8
-        elif isinstance(self.backend, CutlassFusedMoE):
-            # Cutlass always supports workspace output with NVLinkOneSided
-            backend_supports_workspace = True
-
-        if not backend_supports_workspace:
-            # Ensure payload_in_workspace is False if backend doesn't support it
-            self.comm.payload_in_workspace = False
-            return None
 
         # Calculate runtime max tokens per rank
         assert all_rank_num_tokens is not None, (
@@ -1022,7 +1016,6 @@ def _get_nvlink_onesided_moe_output(
 
         # Dynamically enable payload_in_workspace for this forward pass
         self.comm.payload_in_workspace = True
-
         return moe_output
 
     def _get_backend_kwargs(
@@ -1096,13 +1089,18 @@ def _get_backend_kwargs(
 
             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )
 
         # CuteDSL-specific parameters
         elif self.backend.__class__ == CuteDslFusedMoE:
             kwargs["enable_alltoall"] = self.enable_alltoall
 
+            # Get moe_output for NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
+            )
+
         # DeepGemm-specific parameters
         elif self.backend.__class__ == DeepGemmFusedMoE:
             if workspace is not None:
@@ -1123,7 +1121,7 @@ def _get_backend_kwargs(
 
             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )
 
         return kwargs
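
The net effect of the configurable_moe.py changes is that the backend capability check moves out of _get_nvlink_onesided_moe_output and into the backends themselves, via supports_moe_output_in_alltoall_workspace(): the orchestrator either hands the backend a workspace-backed output buffer or None. A self-contained sketch of that pattern is below; all class names, the helper, and the byte arithmetic are placeholders for illustration, not the repository's classes, and the real workspace carving is not shown in this diff.

import torch

class _Backend:
    def supports_moe_output_in_alltoall_workspace(self) -> bool:
        return False  # conservative default, as in interface.py

class _NvFp4Backend(_Backend):
    def supports_moe_output_in_alltoall_workspace(self) -> bool:
        return True  # e.g. CuteDSL with NVFP4 after this commit

def get_moe_output_view(backend, workspace, num_tokens, hidden_size, dtype):
    """Return a workspace-backed output buffer if the backend allows it,
    otherwise None (the backend then allocates its own output)."""
    if not backend.supports_moe_output_in_alltoall_workspace():
        return None
    numel = num_tokens * hidden_size
    return workspace[:numel * dtype.itemsize].view(dtype).view(
        num_tokens, hidden_size)

workspace = torch.empty(1 << 20, dtype=torch.uint8)  # stand-in alltoall workspace
out = get_moe_output_view(_NvFp4Backend(), workspace, 8, 128, torch.bfloat16)
print(None if out is None else out.shape)  # torch.Size([8, 128])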

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 48 additions & 34 deletions
@@ -210,6 +210,9 @@ def _get_quant_method(self):
             return NVFP4CuteDslFusedMoEMethod()
         return super()._get_quant_method()
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_nvfp4
+
     def quantize_input(self,
                        x: Union[torch.Tensor, Fp4QuantizedTensor],
                        post_quant_comm: bool = True):
@@ -258,6 +261,7 @@ def run_moe_nvfp4(
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         assert self.has_nvfp4
@@ -274,6 +278,16 @@ def run_moe_nvfp4(
             tile_tokens_dim=tile_size,
         )
 
+        if moe_output is None:
+            moe_output = torch.empty(
+                (token_final_scales.size(0), self.hidden_size),
+                dtype=output_dtype,
+                device=x.device)
+        else:
+            assert moe_output.size() == (token_final_scales.size(0),
+                                         self.hidden_size)
+            assert moe_output.dtype == output_dtype
+
         x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
@@ -291,12 +305,10 @@ def run_moe_nvfp4(
             local_expert_offset=self.slot_start,
             tile_size=tile_size,
         )
+
         if self.use_fused_finalize:
-            output = torch.empty((token_final_scales.size(0), self.hidden_size),
-                                 dtype=output_dtype,
-                                 device=x.device)
             torch.ops.trtllm.moe_output_memset_inplace(
-                input=output,
+                input=moe_output,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@@ -313,7 +325,7 @@ def run_moe_nvfp4(
                 weight_scale=self.quant_scales.fc2_weight_block.view(
                     torch.uint8),
                 alpha=self.quant_scales.fc2_global,
-                output=output,
+                output=moe_output,
                 tile_idx_to_group_idx=tile_idx_to_expert_idx,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@@ -326,7 +338,6 @@ def run_moe_nvfp4(
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = output
         else:
             x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
                 input=x.view(torch.float4_e2m1fn_x2),
@@ -344,12 +355,13 @@ def run_moe_nvfp4(
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = torch.ops.trtllm.moe_unpermute(
+            torch.ops.trtllm.moe_unpermute_inplace(
                 permuted_input=x,
+                output=moe_output,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 topk_scales=token_final_scales,
             )
-        return x
+        return moe_output
 
     def run_moe_fp8_block_scales(
         self,
@@ -364,12 +376,12 @@ def run_moe_fp8_block_scales(
         weight_dtype = self.w3_w1_weight.dtype
 
         (
-            permuted_row_to_unpermuted_row_tensor,
-            permuted_token_selected_experts_tensor,
-            permuted_data_tensor,
-            expert_first_token_offset_tensor,
-            permuted_token_final_scales_tensor,
-            unpermuted_row_to_permuted_row_tensor,
+            permuted_row_to_unpermuted_row,
+            permuted_token_selected_experts,
+            x,
+            expert_first_token_offset,
+            permuted_token_final_scales,
+            unpermuted_row_to_permuted_row,
         ) = torch.ops.trtllm.moe_permute_op(
             x,
             token_selected_experts,
@@ -388,35 +400,34 @@ def run_moe_fp8_block_scales(
             min_latency_mode=False,
             use_fp8_block_scaling=True,
         )
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
-            permuted_data_tensor)
-        h1 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w3_w1_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[0],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h2 = swiglu_fused_moe(h1)
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2)
-        h3 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x = swiglu_fused_moe(x)
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w2_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[1],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h4 = torch.ops.trtllm.moe_finalize_scale_op(
-            h3,
+        x = torch.ops.trtllm.moe_finalize_scale_op(
+            x,
             None,  # biases
             token_final_scales,
-            unpermuted_row_to_permuted_row_tensor,
-            permuted_row_to_unpermuted_row_tensor,
+            unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row,
             token_selected_experts,
-            expert_first_token_offset_tensor,
+            expert_first_token_offset,
             enable_alltoall,
-            x.shape[0],  # num_rows
-            x.shape[1],  # (possibly padded) hidden_size
+            token_final_scales.size(0),  # num_rows
+            self.hidden_size,  # (possibly padded) hidden_size
             self.unpadded_hidden_size,  # original hidden size
             self.routing_method.top_k,
             self.expert_size_per_partition,  # num_experts_per_node
@@ -425,14 +436,15 @@ def run_moe_fp8_block_scales(
             self.ep_size,
             self.ep_rank,
         )
-        return h4
+        return x
 
     def run_moe(
         self,
         x: torch.Tensor,
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         """
@@ -448,6 +460,7 @@ def run_moe(
                 this represents expert slots [num_tokens, top_k] instead.
             token_final_scales: Final scaling factors for each token
             x_sf: Input scale factors (optional, for certain quantization schemes)
+            moe_output: Pre-allocated MoE output buffer (optional, for NVLINK one-sided backend).
             enable_alltoall: Whether alltoall communication is enabled.
 
         Returns:
@@ -459,6 +472,7 @@ def run_moe(
                 token_selected_experts=token_selected_experts,
                 token_final_scales=token_final_scales,
                 x_sf=x_sf,
+                moe_output=moe_output,
                 enable_alltoall=enable_alltoall)
         elif self.has_deepseek_fp8_block_scales:
             return self.run_moe_fp8_block_scales(
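
The key behavioural change in fused_moe_cute_dsl.py is that run_moe_nvfp4 no longer always allocates its own output tensor: when the caller passes a pre-allocated moe_output (the NVLink one-sided alltoall workspace), both the fused-finalize path and the moe_unpermute_inplace path write into it directly and the function returns that buffer. A standalone sketch of the allocate-or-validate logic, condensed from the diff above; the free-function form and its name are illustrative, while the checks mirror the added code.

import torch

def get_or_check_moe_output(moe_output, num_tokens, hidden_size, output_dtype,
                            device):
    if moe_output is None:
        # Standalone path: allocate a fresh output buffer, as before.
        return torch.empty((num_tokens, hidden_size),
                           dtype=output_dtype, device=device)
    # Alltoall path: reuse the caller-provided workspace slice, only
    # validating its shape and dtype.
    assert moe_output.size() == (num_tokens, hidden_size)
    assert moe_output.dtype == output_dtype
    return moe_output

# Both finalize paths in run_moe_nvfp4 then target this buffer:
#   - fused finalize passes it as `output=` of the finalize-fusion GEMM,
#   - the non-fused path passes it to torch.ops.trtllm.moe_unpermute_inplace,
# and run_moe_nvfp4 returns moe_output instead of a freshly created tensor.
out = get_or_check_moe_output(None, 8, 128, torch.bfloat16, "cpu")
print(out.shape)  # torch.Size([8, 128])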

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 0 deletions
@@ -389,6 +389,9 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return True
+
     def run_moe(
         self,
         x: torch.Tensor,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 3 additions & 0 deletions
@@ -354,6 +354,9 @@ def quantize_input(self, x, post_quant_comm: bool = True):
 
         return x, x_sf
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_w4a8_mxfp4_mxfp8
+
     def run_moe(
         self,
         x: torch.Tensor,

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 5 additions & 0 deletions
@@ -723,6 +723,11 @@ def enable_alltoall(self):
     def expand_intermediate_size_per_partition(self):
         return self.intermediate_size_per_partition * self.intermediate_size_expand_ratio
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        """ Supports moe_output in alltoall workspace
+        """
+        return False
+
     def reducescatter_or_allreduce(
         self,
         inputs,
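
Taken together, the interface change gives every backend a conservative default (False) that individual backends override to opt in: per this commit, CutlassFusedMoE always opts in, TRTLLMGenFusedMoE opts in only for w4a8_mxfp4_mxfp8, and CuteDslFusedMoE opts in only for NVFP4. A minimal, self-contained sketch of the opt-in pattern; the base-class name and the flag wiring in __init__ are placeholders, not the actual module hierarchy.

class _FusedMoEBase:
    def supports_moe_output_in_alltoall_workspace(self):
        return False  # interface.py default

class CutlassFusedMoE(_FusedMoEBase):
    def supports_moe_output_in_alltoall_workspace(self):
        return True

class TRTLLMGenFusedMoE(_FusedMoEBase):
    def __init__(self, has_w4a8_mxfp4_mxfp8=False):
        self.has_w4a8_mxfp4_mxfp8 = has_w4a8_mxfp4_mxfp8

    def supports_moe_output_in_alltoall_workspace(self):
        return self.has_w4a8_mxfp4_mxfp8

class CuteDslFusedMoE(_FusedMoEBase):
    def __init__(self, has_nvfp4=False):
        self.has_nvfp4 = has_nvfp4

    def supports_moe_output_in_alltoall_workspace(self):
        return self.has_nvfp4

print(CuteDslFusedMoE(has_nvfp4=True)
      .supports_moe_output_in_alltoall_workspace())  # True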
