
Commit ebb62e1

[None][feat] Add alltoall to trtllm-gen MoE backend. (#8481)
Signed-off-by: Bo Li <[email protected]>
1 parent ab4b996 commit ebb62e1
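
This commit adds an MNNVL all-to-all token dispatch/combine path to the TRTLLM-Gen fused MoE backend. The path is gated behind a new enable_alltoall property and can be disabled with the TRTLLM_MOE_DISABLE_ALLTOALLV environment variable. As a minimal standalone sketch of that gate, restated outside the class (mapping here is a hypothetical stand-in for the model's parallel mapping, and the MNNVL capability check is passed in as a plain flag):

import os


def alltoall_enabled(mapping, experts_per_token: int, supports_mnnvl: bool) -> bool:
    # Mirrors the predicate added in this diff: all-to-all is used only when
    # expert parallelism exceeds top-k, attention DP is on, TP > 1, the
    # TRTLLM_MOE_DISABLE_ALLTOALLV escape hatch is not set, and MNNVL is available.
    return (mapping.moe_ep_size > experts_per_token
            and mapping.enable_attention_dp
            and mapping.tp_size > 1
            and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
            and supports_mnnvl)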

File tree: 1 file changed (+159, -56 lines)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 159 additions & 56 deletions
@@ -1,8 +1,11 @@
+import os
+from functools import cached_property
 from typing import Dict, List, Optional, Union
 
 import torch
 from torch import nn
 
+from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
 from tensorrt_llm._utils import get_sm_version
 
 from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
+        self.alltoall_workspace = None
+        self.alltoall_prepare_workspace = None
+        if self.enable_alltoall:
+            MnnvlMemory.initialize()
+            self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
+                model_config.mapping)
+            self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
+                model_config.mapping)
+
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
 
+    @cached_property
+    def enable_alltoall(self):
+        mapping = self.mapping
+        routing_experts = self.routing_method.experts_per_token
+        return (mapping.moe_ep_size > routing_experts
+                and mapping.enable_attention_dp and mapping.tp_size > 1
+                and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
+                and MnnvlMemory.supports_mnnvl())
+
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
             or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -175,6 +196,48 @@ def load_weights(self, weights: List[Dict]):
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
 
+    def _quantize_for_post_quant_comm(self, x):
+        """Quantize inputs prior to post-communication (alltoall/allgather).
+        Returns: (x, x_sf, x_row, x_col)
+        """
+        x_row = x.shape[0]
+        x_col = x.shape[1]
+        x_sf = None
+        if self.has_w4a8_mxfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, self.fc31_input_dequant[0])
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_nvfp4:
+            if isinstance(x, Fp4QuantizedTensor):
+                assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
+                x_row = x.shape[0]
+                x_col = x.shape[1] * 2
+                x, x_sf = x.fp4_tensor, x.scaling_factor
+            else:
+                x_row = x.shape[0]
+                x_col = x.shape[1]
+                x, x_sf = torch.ops.trtllm.fp4_quantize(
+                    x, self.fc31_input_scale, self.scaling_vector_size, False,
+                    False)
+        elif self.has_w4a8_mxfp4_mxfp8:
+            x, x_sf = torch.ops.trtllm.mxfp8_quantize(
+                x, False, alignment=self.quant_method.weight_alignment)
+            x_row, x_col = x.shape[0], x.shape[1]
+        elif self.has_deepseek_fp8_block_scales:
+            # No change required before communication
+            pass
+        elif self.has_w4a16_mxfp4:
+            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+            x = torch.nn.functional.pad(x, (0, pad_size))
+        elif self.has_w4a8_nvfp4_fp8:
+            x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                x, 1.0 / self.fc31_input_scale)
+        else:
+            raise ValueError(
+                f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
+            )
+        return x, x_sf, x_row, x_col
+
     def forward_impl(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -202,55 +265,80 @@ def forward_impl(
         topk_group = None
         routed_scaling_factor = None
 
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
+        run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
+                                    and not self.enable_alltoall)
+        post_quant_comm = run_post_quant_allgather or self.enable_alltoall
 
         x_sf = None
         token_selected_experts = None
         token_final_scales = None
         x_row = x.shape[0]
         x_col = x.shape[1]
-        if run_post_quant_allgather:
-            # apply routing
+        token_count = x.shape[0]
+        alltoall_info = None
+
+        if post_quant_comm:
             token_selected_experts, token_final_scales = self.routing_method.apply(
                 router_logits)
-            token_final_scales = token_final_scales.to(torch.bfloat16)
-            assert token_final_scales.dtype == torch.bfloat16
-            assert token_selected_experts.dtype == torch.int32
-            # quantize inputs
-            if self.has_w4a8_mxfp4_fp8:
-                x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
-                    x, self.fc31_input_dequant[0])
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_nvfp4:
-                if isinstance(x, Fp4QuantizedTensor):
-                    assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
-                    x_row = x.shape[0]
-                    # note: we use uint8 to store 2 fp4 values
-                    x_col = x.shape[1] * 2
-                    x, x_sf = x.fp4_tensor, x.scaling_factor
-                else:
-                    x_row = x.shape[0]
-                    x_col = x.shape[1]
-                    x, x_sf = torch.ops.trtllm.fp4_quantize(
-                        x, self.fc31_input_scale, self.scaling_vector_size,
-                        False, False)
-            elif self.has_w4a8_mxfp4_mxfp8:
-                x, x_sf = torch.ops.trtllm.mxfp8_quantize(
-                    x, False, alignment=self.quant_method.weight_alignment)
-                # Update x_row and x_col to the padded shape
-                x_row, x_col = x.shape[0], x.shape[1]
-            elif self.has_deepseek_fp8_block_scales:
-                pass
-            elif self.has_w4a16_mxfp4:
-                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-                x = torch.nn.functional.pad(x, (0, pad_size))
+            token_selected_experts = token_selected_experts.to(torch.int32)
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
+            x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
+
+        if self.enable_alltoall:
+            assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"
+
+            max_num_token = max(
+                all_rank_num_tokens) if all_rank_num_tokens else token_count
+
+            if token_final_scales is None:
+                token_final_scales = torch.ones_like(token_selected_experts,
+                                                     dtype=torch.float32)
             else:
-                raise ValueError(
-                    f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
-                )
+                token_final_scales = token_final_scales.to(torch.float32)
+
+            assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
+            alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                token_selected_experts,
+                None,
+                self.alltoall_prepare_workspace,
+                max_num_token,
+                self.ep_rank,
+                self.ep_size,
+                self.num_experts,
+                self.num_slots,
+                top_k,
+            )
 
-        #allgather for attention DP
+            if x_sf is not None:
+                x_sf = x_sf.view(x_row, ceil_div(x_col,
+                                                 self.scaling_vector_size))
+
+            x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
+                [x, x_sf, token_selected_experts, token_final_scales],
+                alltoall_info,
+                self.alltoall_workspace,
+                self.ep_rank,
+                self.ep_size,
+            )
+
+            torch.ops.trtllm.memset_expert_ids(
+                token_selected_experts,
+                alltoall_info.recv_rank_count_cumsum,
+                max_num_token,
+                top_k,
+                self.num_slots,
+                self.ep_size,
+            )
+
+            if x_sf is not None:
+                x_sf = x_sf.flatten()
+
+            if token_final_scales is not None:
+                token_final_scales = token_final_scales.to(torch.bfloat16)
+
+        elif run_post_quant_allgather:
             if x_sf is not None:
                 x_sf = x_sf.view(x_row, ceil_div(x_col,
                                                  self.scaling_vector_size))
@@ -265,15 +353,18 @@ def forward_impl(
             if x_sf is not None:
                 x_sf = x_sf.flatten()
 
+        router_logits_arg = router_logits if not post_quant_comm else None
+        routing_bias_arg = routing_bias if not post_quant_comm else None
+
         # TODO: since routing kernel is integrated into moe_runner for fp8,
         # here we just route the I/Os for moe_runner
         if self.has_deepseek_fp8_block_scales:
             assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -297,7 +388,7 @@ def forward_impl(
             scale_factor_use_ue8m0 = False
             is_scale_factor_swizzled = False  # use linear layout here
 
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = (
                     torch.ops.trtllm.fp4_quantize(
                         x,
@@ -310,8 +401,8 @@ def forward_impl(
                 hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf
 
             outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp4,
                 hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
                 self.w3_w1_weight,
@@ -343,7 +434,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
            assert x.dtype == torch.bfloat16
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
                 x = torch.nn.functional.pad(x, (0, pad_size))
             else:
@@ -352,8 +443,8 @@ def forward_impl(
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ def forward_impl(
                                                       hidden_size].contiguous()
         elif self.has_w4a8_nvfp4_fp8:
 
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, 1.0 / self.fc31_input_scale)
             else:
                 hidden_states_fp8 = x
 
             outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits_arg,
+                routing_bias_arg,
                 hidden_states_fp8,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 x = torch.nn.functional.pad(x, (0, pad_size))
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ def forward_impl(
             final_hidden_states = final_hidden_states[:, :self.
                                                       hidden_size].contiguous()
         elif self.has_w4a8_mxfp4_mxfp8:
-            if not run_post_quant_allgather:
+            if not post_quant_comm:
                 # TRTLLM-Gen uses linear SF layout for the mxfp8 input.
                 mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
                     x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ def forward_impl(
                 -2] // 2
 
             final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits_arg,
+                routing_bias_arg,
                 mxfp8_x,
                 sf,
                 self.w3_w1_weight,
@@ -511,6 +602,18 @@ def forward_impl(
                 "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
             )
 
+        # Combine results if using alltoall
+        if self.enable_alltoall and alltoall_info is not None:
+            final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
+                final_hidden_states,
+                alltoall_info,
+                self.alltoall_workspace,
+                ep_rank=self.ep_rank,
+                ep_size=self.ep_size,
+                top_k=top_k,
+                token_count=token_count,
+            )
+
         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
             all_rank_num_tokens=all_rank_num_tokens,
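
Taken together, the new forward_impl branch follows a dispatch, local expert compute, combine pattern: routed and quantized tokens are sent to the ranks owning their selected experts via MnnvlMoe.mnnvl_moe_alltoallv, run through the quantization-specific *_moe_runner op, and the results are returned to their source ranks via MnnvlMoe.mnnvl_moe_alltoallv_combine before the usual reduce-scatter/all-reduce. Below is a conceptual, single-process emulation of that data-movement pattern only; it is not the TensorRT-LLM implementation, and the expert-to-rank mapping is a toy choice for illustration:

import torch


def emulate_alltoall_moe(x: torch.Tensor, experts: torch.Tensor, ep_size: int) -> torch.Tensor:
    # x: [tokens, hidden]; experts: [tokens], one expert id per token (top-1 for brevity).
    owner = experts % ep_size               # toy expert -> EP-rank assignment
    order = torch.argsort(owner)            # "dispatch": group tokens by destination rank
    dispatched = x[order]
    processed = dispatched * 2.0            # stand-in for the local MoE runner
    combined = torch.empty_like(processed)
    combined[order] = processed             # "combine": scatter results back to original token slots
    return combined


out = emulate_alltoall_moe(torch.randn(8, 16), torch.randint(0, 32, (8,)), ep_size=4)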
