@@ -211,6 +211,10 @@ def __init__(
         if not model_config.skip_create_weights_in_init:
             self.create_weights()
 
+        # Debug flag for eliminating imbalance during performance analysis.
+        self.enable_dummy_allreduce = os.environ.get(
+            "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
+
     def _check_configs(self):
         assert self._weights_created
 
@@ -302,6 +306,16 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
+    def dummy_allreduce(self):
+        """
+        Debug function for eliminating imbalance during performance analysis.
+        Creates a small dummy tensor and performs allreduce to synchronize processes
+        and eliminate timing imbalances for more accurate profiling measurements.
+        """
+        dummy_tensor = torch.zeros(4, dtype=torch.float32, device='cuda')
+        dummy_tensor = self.all_reduce(dummy_tensor)
+        return dummy_tensor
+
     def reducescatter_or_allreduce(
         self,
         inputs,
@@ -311,6 +325,8 @@ def reducescatter_or_allreduce(
         outputs = inputs
         if self.parallel_size > 1 and not self.enable_alltoall:
             if self.use_dp:
+                if self.enable_dummy_allreduce:
+                    self.dummy_allreduce()
                 outputs = reducescatter(
                     inputs,
                     self.mapping,
@@ -398,6 +414,8 @@ def forward_chunk(
 
         if self.enable_alltoall:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                if self.enable_dummy_allreduce:
+                    self.dummy_allreduce()
                 token_count = x.shape[0]
                 alltoall_info = None
                 x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
@@ -482,6 +500,8 @@ def forward_chunk(
 
         if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
         ) and not self.enable_alltoall:
+            if self.enable_dummy_allreduce:
+                self.dummy_allreduce()
             x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
                 [
                     x,
@@ -630,6 +650,8 @@ def forward_chunk(
 
         if self.enable_alltoall:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+                if self.enable_dummy_allreduce:
+                    self.dummy_allreduce()
                 final_hidden_states = self.alltoall_combine(
                     final_hidden_states, alltoall_info, token_count)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
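
Usage note: the flag is read from the environment once in __init__, so it has to be set before the MoE module is constructed. Below is a minimal sketch of how a profiling run might enable it; only the TRTLLM_ENABLE_DUMMY_ALLREDUCE variable and its parsing come from the diff above, everything else is illustrative.

    import os

    # Must be set before the module is constructed; __init__ reads the variable only once.
    os.environ["TRTLLM_ENABLE_DUMMY_ALLREDUCE"] = "1"

    # Same parsing as in the diff above: any value other than "1" leaves the debug path disabled.
    enable_dummy_allreduce = os.environ.get("TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
    assert enable_dummy_allreduce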