Skip to content

Commit 089fd55

Browse files
authored
Add dummy all_reduce for kernel breakdown (NVIDIA#5745)
Signed-off-by: Xianjie <[email protected]>
1 parent 1b588f8 commit 089fd55

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ def __init__(
211211
if not model_config.skip_create_weights_in_init:
212212
self.create_weights()
213213

214+
# Debug function for eliminating imbalance during performance analysis.
215+
self.enable_dummy_allreduce = os.environ.get(
216+
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
217+
214218
def _check_configs(self):
215219
assert self._weights_created
216220

@@ -302,6 +306,16 @@ def create_weights(self):
302306
self._weights_created = True
303307
self._check_configs()
304308

309+
def dummy_allreduce(self):
    """Run a throwaway allreduce over a tiny tensor to synchronize ranks.

    Debug helper for eliminating imbalance during performance analysis:
    all participating processes reduce a 4-element zero tensor, which acts
    as a barrier-like sync point so profiling measurements are not skewed
    by per-rank timing imbalance.

    Returns:
        The tensor produced by ``self.all_reduce`` on the dummy input.
    """
    # Tiny fixed-size CUDA tensor: the payload is irrelevant, only the
    # collective's synchronization effect matters.
    sync_probe = torch.zeros(4, dtype=torch.float32, device='cuda')
    return self.all_reduce(sync_probe)
318+
305319
def reducescatter_or_allreduce(
306320
self,
307321
inputs,
@@ -311,6 +325,8 @@ def reducescatter_or_allreduce(
311325
outputs = inputs
312326
if self.parallel_size > 1 and not self.enable_alltoall:
313327
if self.use_dp:
328+
if self.enable_dummy_allreduce:
329+
self.dummy_allreduce()
314330
outputs = reducescatter(
315331
inputs,
316332
self.mapping,
@@ -398,6 +414,8 @@ def forward_chunk(
398414

399415
if self.enable_alltoall:
400416
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
417+
if self.enable_dummy_allreduce:
418+
self.dummy_allreduce()
401419
token_count = x.shape[0]
402420
alltoall_info = None
403421
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
@@ -482,6 +500,8 @@ def forward_chunk(
482500

483501
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
484502
) and not self.enable_alltoall:
503+
if self.enable_dummy_allreduce:
504+
self.dummy_allreduce()
485505
x, x_sf, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
486506
[
487507
x,
@@ -630,6 +650,8 @@ def forward_chunk(
630650

631651
if self.enable_alltoall:
632652
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
653+
if self.enable_dummy_allreduce:
654+
self.dummy_allreduce()
633655
final_hidden_states = self.alltoall_combine(
634656
final_hidden_states, alltoall_info, token_count)
635657
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:

0 commit comments

Comments
 (0)