Skip to content

Commit 62b7718

Browse files
authored
[TRTLLM-9389][chore] Refactor AlltoallMethodType. (#9388)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
1 parent 2d5eadf commit 62b7718

File tree

5 files changed

+95
-122
lines changed

5 files changed

+95
-122
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -143,34 +143,29 @@ def __init__(
143143
if self.enable_alltoall:
144144
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
145145

146-
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
147-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
148-
MnnvlMemory.initialize()
149-
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
150-
model_config.mapping)
151-
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
152-
model_config.mapping)
153-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
154-
workspace_mb = int(
155-
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
156-
self.moe_a2a = MoeAlltoAll(
157-
mapping=self.mapping,
158-
max_num_tokens=model_config.max_num_tokens,
159-
top_k=self.routing_method.experts_per_token,
160-
num_experts=self.num_slots,
161-
workspace_size_per_rank=workspace_mb * 1024 * 1024,
162-
)
163-
else:
164-
raise ValueError(
165-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
166-
)
146+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
147+
MnnvlMemory.initialize()
148+
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
149+
model_config.mapping)
150+
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
151+
model_config.mapping)
152+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
153+
workspace_mb = int(
154+
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
155+
self.moe_a2a = MoeAlltoAll(
156+
mapping=self.mapping,
157+
max_num_tokens=model_config.max_num_tokens,
158+
top_k=self.routing_method.experts_per_token,
159+
num_experts=self.num_slots,
160+
workspace_size_per_rank=workspace_mb * 1024 * 1024,
161+
)
167162
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
168163
raise NotImplementedError(
169164
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
170165
)
171166
else:
172167
raise NotImplementedError(
173-
f"Not available alltoall method type: {self.alltoall_method_type!r}"
168+
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
174169
)
175170

176171
# If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -236,28 +231,18 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
236231
)
237232
return AlltoallMethodType[all2all_method_type]
238233

239-
if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
240-
return AlltoallMethodType.NotEnabled
241-
242-
# TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
243-
# regardless of the relationship between EP size and topK. We favor AlltoAll for now.
234+
# TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
235+
# regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
244236
# if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
245237
# return AlltoallMethodType.NotEnabled
246-
247-
return AlltoallMethodType.MNNVL
238+
return AlltoallMethodType.NVLinkOneSided
248239

249240
@cached_property
250241
def enable_alltoall(self):
251242
""" enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
252243
"""
253244
return self.alltoall_method_type != AlltoallMethodType.NotEnabled
254245

255-
@cached_property
256-
def moe_alltoall_backend(self):
257-
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
258-
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
259-
"NVLINK_ONE_SIDED").strip().upper()
260-
261246
def _supports_load_balancer(self) -> bool:
262247
"""CutlassFusedMoE supports load balancer."""
263248
return True
@@ -329,7 +314,7 @@ def forward_chunk(
329314

330315
if self.layer_load_balancer:
331316
self._load_balancer_done_wait_gpu_stage(is_first_call)
332-
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
317+
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
333318
self._load_balancer_update_statistic(
334319
token_selected_experts,
335320
is_first_call,
@@ -440,7 +425,7 @@ def forward_chunk(
440425
token_final_scales = torch.ones_like(token_selected_slots,
441426
dtype=torch.float32)
442427

443-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
428+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
444429
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
445430
if is_last_call:
446431
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -473,7 +458,7 @@ def forward_chunk(
473458
token_selected_slots, alltoall_info.recv_rank_count_cumsum,
474459
runtime_max_tokens_per_rank, top_k, self.num_slots,
475460
self.ep_size)
476-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
461+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
477462
# Python MoeAlltoAll path
478463
if x_sf is not None:
479464
x_sf = x_sf.view(x_row,
@@ -511,7 +496,7 @@ def forward_chunk(
511496
-1, token_final_scales_recv.shape[-1])
512497
else:
513498
raise ValueError(
514-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
499+
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
515500
)
516501

517502
elif run_post_quant_allgather:
@@ -533,7 +518,7 @@ def forward_chunk(
533518

534519
# Optionally provide an output tensor to fused_moe so it writes directly to our buffer
535520
moe_output: Optional[torch.Tensor] = None
536-
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
521+
if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
537522
# Retrieve a workspace-backed output tensor sized by runtime tokens
538523
runtime_max_tokens_per_rank = max(
539524
all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
@@ -584,7 +569,7 @@ def forward_chunk(
584569

585570
# Combine results if using alltoall
586571
if self.enable_alltoall:
587-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
572+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
588573
if alltoall_info is not None:
589574
top_k = self.routing_method.experts_per_token
590575
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
@@ -597,7 +582,7 @@ def forward_chunk(
597582
use_low_precision_combine=self.
598583
use_low_precision_combine,
599584
token_count=token_count)
600-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
585+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
601586
output_hidden_size = final_hidden_states.shape[-1]
602587
runtime_max_tokens_per_rank = max(
603588
all_rank_num_tokens) if all_rank_num_tokens else token_count
@@ -609,7 +594,7 @@ def forward_chunk(
609594
payload_in_workspace=True)
610595
else:
611596
raise ValueError(
612-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
597+
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
613598
)
614599

615600
self._load_balancer_done_set_cpu_stage(is_last_call)
@@ -709,7 +694,10 @@ def _reducescatter_or_allreduce(x_, idx):
709694
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
710695
for idx_chunk, (x, router_logits) in enumerate(
711696
zip(x_list, router_logits_list)):
712-
if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
697+
if not (self.alltoall_method_type
698+
== AlltoallMethodType.NVLinkOneSided
699+
or self.alltoall_method_type
700+
== AlltoallMethodType.NVLinkTwoSided):
713701
if idx_chunk % 2 == 0:
714702
with torch.cuda.stream(self.aux_stream):
715703
outputs = _forward_chunk(x, router_logits,
@@ -727,7 +715,10 @@ def _reducescatter_or_allreduce(x_, idx):
727715

728716
outputs_list.append(outputs)
729717

730-
if not (self.alltoall_method_type == AlltoallMethodType.MNNVL):
718+
if not (self.alltoall_method_type
719+
== AlltoallMethodType.NVLinkOneSided
720+
or self.alltoall_method_type
721+
== AlltoallMethodType.NVLinkTwoSided):
731722
if num_chunks % 2 == 0:
732723
outputs_list[-1] = _reducescatter_or_allreduce(
733724
outputs_list[-1], -1)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -121,34 +121,29 @@ def __init__(
121121
if self.enable_alltoall:
122122
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
123123

124-
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
125-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
126-
MnnvlMemory.initialize()
127-
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
128-
model_config.mapping)
129-
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
130-
model_config.mapping)
131-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
132-
workspace_mb = int(
133-
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
134-
self.moe_a2a = MoeAlltoAll(
135-
mapping=self.mapping,
136-
max_num_tokens=model_config.max_num_tokens,
137-
top_k=self.routing_method.experts_per_token,
138-
num_experts=self.num_slots,
139-
workspace_size_per_rank=workspace_mb * 1024 * 1024,
140-
)
141-
else:
142-
raise ValueError(
143-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
144-
)
124+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
125+
MnnvlMemory.initialize()
126+
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
127+
model_config.mapping)
128+
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
129+
model_config.mapping)
130+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
131+
workspace_mb = int(
132+
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
133+
self.moe_a2a = MoeAlltoAll(
134+
mapping=self.mapping,
135+
max_num_tokens=model_config.max_num_tokens,
136+
top_k=self.routing_method.experts_per_token,
137+
num_experts=self.num_slots,
138+
workspace_size_per_rank=workspace_mb * 1024 * 1024,
139+
)
145140
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
146141
raise NotImplementedError(
147142
"DeepEP and DeepEPLowLatency are not supported for TRTLLMGenFusedMoE yet"
148143
)
149144
else:
150145
raise NotImplementedError(
151-
f"Not available alltoall method type: {self.alltoall_method_type!r}"
146+
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
152147
)
153148

154149
self._weights_created = False
@@ -178,15 +173,11 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
178173
)
179174
return AlltoallMethodType[all2all_method_type]
180175

181-
if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
182-
return AlltoallMethodType.NotEnabled
183-
184-
# TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
185-
# regardless of the relationship between EP size and topK. We favor AlltoAll for now.
176+
# TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
177+
# regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
186178
# if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
187179
# return AlltoallMethodType.NotEnabled
188-
189-
return AlltoallMethodType.MNNVL
180+
return AlltoallMethodType.NVLinkOneSided
190181

191182
def _supports_load_balancer(self) -> bool:
192183
"""TRTLLMGenFusedMoE supports load balancer."""
@@ -198,12 +189,6 @@ def enable_alltoall(self):
198189
"""
199190
return self.alltoall_method_type != AlltoallMethodType.NotEnabled
200191

201-
@cached_property
202-
def moe_alltoall_backend(self):
203-
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
204-
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
205-
"NVLINK_ONE_SIDED").strip().upper()
206-
207192
def _check_configs(self):
208193
assert self.has_deepseek_fp8_block_scales \
209194
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -375,7 +360,7 @@ def forward_impl(
375360

376361
self._load_balancer_done_wait_gpu_stage(is_first_call)
377362

378-
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
363+
ignore_allreduce = self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided
379364
self._load_balancer_update_statistic(
380365
token_selected_experts,
381366
is_first_call,
@@ -407,7 +392,7 @@ def forward_impl(
407392
else:
408393
token_final_scales = token_final_scales.to(torch.float32)
409394

410-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
395+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
411396
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
412397
if is_last_call:
413398
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -457,7 +442,7 @@ def forward_impl(
457442

458443
if token_final_scales is not None:
459444
token_final_scales = token_final_scales.to(torch.bfloat16)
460-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
445+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
461446
if x_sf is not None:
462447
x_sf = x_sf.view(x_row,
463448
ceil_div(x_col, self.scaling_vector_size))
@@ -499,7 +484,7 @@ def forward_impl(
499484
token_final_scales = token_final_scales.to(torch.bfloat16)
500485
else:
501486
raise ValueError(
502-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
487+
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
503488
)
504489

505490
elif run_post_quant_allgather:
@@ -523,7 +508,7 @@ def forward_impl(
523508
moe_output: Optional[torch.Tensor] = None
524509
use_workspace_output = False
525510
# TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
526-
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED" and self.has_w4a8_mxfp4_mxfp8:
511+
if self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided and self.has_w4a8_mxfp4_mxfp8:
527512
moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
528513
runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
529514
use_workspace_output = True
@@ -787,7 +772,7 @@ def forward_impl(
787772

788773
# Combine results if using alltoall
789774
if self.enable_alltoall:
790-
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
775+
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
791776
if alltoall_info is not None:
792777
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
793778
final_hidden_states,
@@ -800,7 +785,7 @@ def forward_impl(
800785
use_low_precision_combine,
801786
token_count=token_count,
802787
)
803-
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
788+
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
804789
# If use_workspace_output=True, the MoE result is already in workspace
805790
# Otherwise, we need to reshape and pass it
806791
if use_workspace_output:
@@ -823,7 +808,7 @@ def forward_impl(
823808
payload_in_workspace=False)
824809
else:
825810
raise ValueError(
826-
f"Unsupported moe alltoall backend: {self.moe_alltoall_backend}"
811+
f"Unsupported moe alltoall method type: {self.alltoall_method_type}"
827812
)
828813

829814
final_hidden_states = self.reducescatter_or_allreduce(

0 commit comments

Comments (0)