@@ -143,34 +143,29 @@ def __init__(
143143 if self .enable_alltoall :
144144 self .use_low_precision_combine = model_config .use_low_precision_moe_combine
145145
146- if self .alltoall_method_type == AlltoallMethodType .MNNVL :
147- if self .moe_alltoall_backend == "NVLINK_TWO_SIDED" :
148- MnnvlMemory .initialize ()
149- self .alltoall_workspace = MnnvlMoe .get_moe_workspaces (
150- model_config .mapping )
151- self .alltoall_prepare_workspace = MnnvlMoe .get_moe_prepare_workspace (
152- model_config .mapping )
153- elif self .moe_alltoall_backend == "NVLINK_ONE_SIDED" :
154- workspace_mb = int (
155- os .environ .get ("TRTLLM_MOE_A2A_WORKSPACE_MB" , "2048" ))
156- self .moe_a2a = MoeAlltoAll (
157- mapping = self .mapping ,
158- max_num_tokens = model_config .max_num_tokens ,
159- top_k = self .routing_method .experts_per_token ,
160- num_experts = self .num_slots ,
161- workspace_size_per_rank = workspace_mb * 1024 * 1024 ,
162- )
163- else :
164- raise ValueError (
165- f"Unsupported moe alltoall backend: { self .moe_alltoall_backend } "
166- )
146+ if self .alltoall_method_type == AlltoallMethodType .NVLinkTwoSided :
147+ MnnvlMemory .initialize ()
148+ self .alltoall_workspace = MnnvlMoe .get_moe_workspaces (
149+ model_config .mapping )
150+ self .alltoall_prepare_workspace = MnnvlMoe .get_moe_prepare_workspace (
151+ model_config .mapping )
152+ elif self .alltoall_method_type == AlltoallMethodType .NVLinkOneSided :
153+ workspace_mb = int (
154+ os .environ .get ("TRTLLM_MOE_A2A_WORKSPACE_MB" , "2048" ))
155+ self .moe_a2a = MoeAlltoAll (
156+ mapping = self .mapping ,
157+ max_num_tokens = model_config .max_num_tokens ,
158+ top_k = self .routing_method .experts_per_token ,
159+ num_experts = self .num_slots ,
160+ workspace_size_per_rank = workspace_mb * 1024 * 1024 ,
161+ )
167162 elif self .alltoall_method_type == AlltoallMethodType .DeepEP or self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency :
168163 raise NotImplementedError (
169164 "DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
170165 )
171166 else :
172167 raise NotImplementedError (
173- f"Not available alltoall method type: { self .alltoall_method_type !r} "
168+ f"Unsupported alltoall method type: { self .alltoall_method_type !r} "
174169 )
175170
176171 # If True, the router weight will be multiplied on the input rather than at the end of FC2
@@ -236,28 +231,18 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
236231 )
237232 return AlltoallMethodType [all2all_method_type ]
238233
239- if os .environ .get ("TRTLLM_MOE_DISABLE_ALLTOALLV" , "0" ) == "1" :
240- return AlltoallMethodType .NotEnabled
241-
242- # TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
243- # regardless of the relationship between EP size and topK. We favor AlltoAll for now.
234+ # TODO: We found that NVLinkOneSided performs better than NCCL AllGather/ReduceScatter,
235+ # regardless of the relationship between EP size and topK. We favor NVLinkOneSided for now.
244236 # if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
245237 # return AlltoallMethodType.NotEnabled
246-
247- return AlltoallMethodType .MNNVL
238+ return AlltoallMethodType .NVLinkOneSided
248239
249240 @cached_property
250241 def enable_alltoall (self ):
251242 """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
252243 """
253244 return self .alltoall_method_type != AlltoallMethodType .NotEnabled
254245
255- @cached_property
256- def moe_alltoall_backend (self ):
257- # "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
258- return os .environ .get ("TRTLLM_MOE_ALLTOALL_BACKEND" ,
259- "NVLINK_ONE_SIDED" ).strip ().upper ()
260-
261246 def _supports_load_balancer (self ) -> bool :
262247 """CutlassFusedMoE supports load balancer."""
263248 return True
@@ -329,7 +314,7 @@ def forward_chunk(
329314
330315 if self .layer_load_balancer :
331316 self ._load_balancer_done_wait_gpu_stage (is_first_call )
332- ignore_allreduce = self .enable_alltoall and self . alltoall_method_type == AlltoallMethodType .MNNVL and self . moe_alltoall_backend == "NVLINK_TWO_SIDED"
317+ ignore_allreduce = self .alltoall_method_type == AlltoallMethodType .NVLinkTwoSided
333318 self ._load_balancer_update_statistic (
334319 token_selected_experts ,
335320 is_first_call ,
@@ -440,7 +425,7 @@ def forward_chunk(
440425 token_final_scales = torch .ones_like (token_selected_slots ,
441426 dtype = torch .float32 )
442427
443- if self .moe_alltoall_backend == "NVLINK_TWO_SIDED" :
428+ if self .alltoall_method_type == AlltoallMethodType . NVLinkTwoSided :
444429 assert self .alltoall_prepare_workspace is not None , "alltoall_prepare_workspace should be initialized"
445430 if is_last_call :
446431 loadbalancer_local_statistic_info = self ._load_balancer_get_local_statistic_tensor (
@@ -473,7 +458,7 @@ def forward_chunk(
473458 token_selected_slots , alltoall_info .recv_rank_count_cumsum ,
474459 runtime_max_tokens_per_rank , top_k , self .num_slots ,
475460 self .ep_size )
476- elif self .moe_alltoall_backend == "NVLINK_ONE_SIDED" :
461+ elif self .alltoall_method_type == AlltoallMethodType . NVLinkOneSided :
477462 # Python MoeAlltoAll path
478463 if x_sf is not None :
479464 x_sf = x_sf .view (x_row ,
@@ -511,7 +496,7 @@ def forward_chunk(
511496 - 1 , token_final_scales_recv .shape [- 1 ])
512497 else :
513498 raise ValueError (
514- f"Unsupported moe alltoall backend : { self .moe_alltoall_backend } "
499+ f"Unsupported moe alltoall method type: { self .alltoall_method_type } "
515500 )
516501
517502 elif run_post_quant_allgather :
@@ -533,7 +518,7 @@ def forward_chunk(
533518
534519 # Optionally provide an output tensor to fused_moe so it writes directly to our buffer
535520 moe_output : Optional [torch .Tensor ] = None
536- if self .enable_alltoall and self . moe_alltoall_backend == "NVLINK_ONE_SIDED" :
521+ if self .alltoall_method_type == AlltoallMethodType . NVLinkOneSided :
537522 # Retrieve a workspace-backed output tensor sized by runtime tokens
538523 runtime_max_tokens_per_rank = max (
539524 all_rank_num_tokens ) if all_rank_num_tokens else x .shape [0 ]
@@ -584,7 +569,7 @@ def forward_chunk(
584569
585570 # Combine results if using alltoall
586571 if self .enable_alltoall :
587- if self .moe_alltoall_backend == "NVLINK_TWO_SIDED" :
572+ if self .alltoall_method_type == AlltoallMethodType . NVLinkTwoSided :
588573 if alltoall_info is not None :
589574 top_k = self .routing_method .experts_per_token
590575 final_hidden_states = MnnvlMoe .mnnvl_moe_alltoallv_combine (
@@ -597,7 +582,7 @@ def forward_chunk(
597582 use_low_precision_combine = self .
598583 use_low_precision_combine ,
599584 token_count = token_count )
600- elif self .moe_alltoall_backend == "NVLINK_ONE_SIDED" :
585+ elif self .alltoall_method_type == AlltoallMethodType . NVLinkOneSided :
601586 output_hidden_size = final_hidden_states .shape [- 1 ]
602587 runtime_max_tokens_per_rank = max (
603588 all_rank_num_tokens ) if all_rank_num_tokens else token_count
@@ -609,7 +594,7 @@ def forward_chunk(
609594 payload_in_workspace = True )
610595 else :
611596 raise ValueError (
612- f"Unsupported moe alltoall backend : { self .moe_alltoall_backend } "
597+ f"Unsupported moe alltoall method type: { self .alltoall_method_type } "
613598 )
614599
615600 self ._load_balancer_done_set_cpu_stage (is_last_call )
@@ -709,7 +694,10 @@ def _reducescatter_or_allreduce(x_, idx):
709694 # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
710695 for idx_chunk , (x , router_logits ) in enumerate (
711696 zip (x_list , router_logits_list )):
712- if not (self .alltoall_method_type == AlltoallMethodType .MNNVL ):
697+ if not (self .alltoall_method_type
698+ == AlltoallMethodType .NVLinkOneSided
699+ or self .alltoall_method_type
700+ == AlltoallMethodType .NVLinkTwoSided ):
713701 if idx_chunk % 2 == 0 :
714702 with torch .cuda .stream (self .aux_stream ):
715703 outputs = _forward_chunk (x , router_logits ,
@@ -727,7 +715,10 @@ def _reducescatter_or_allreduce(x_, idx):
727715
728716 outputs_list .append (outputs )
729717
730- if not (self .alltoall_method_type == AlltoallMethodType .MNNVL ):
718+ if not (self .alltoall_method_type
719+ == AlltoallMethodType .NVLinkOneSided
720+ or self .alltoall_method_type
721+ == AlltoallMethodType .NVLinkTwoSided ):
731722 if num_chunks % 2 == 0 :
732723 outputs_list [- 1 ] = _reducescatter_or_allreduce (
733724 outputs_list [- 1 ], - 1 )