Skip to content

Commit e08e0b9

Browse files
nanz-nvvasunvidia
authored and committed
make overload factor logging work for cuda graph
1 parent 2e4db72 commit e08e0b9

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

megatron/core/transformer/moe/moe_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,15 @@ def save_overload_factor_to_tracker(
10031003
"""
10041004
# Set comm groups in tracker (outside autograd function)
10051005
tracker = get_overload_factor_tracker()
1006+
if "to_clear" in tracker and tracker["to_clear"]:
1007+
if "fwd" in tracker:
1008+
tracker["fwd"].clear()
1009+
if "fwd_bwd" in tracker:
1010+
tracker["fwd_bwd"].clear()
1011+
tracker.pop("tp_ep_group", None)
1012+
tracker.pop("dp_group", None)
1013+
tracker.pop("to_clear", None)
1014+
10061015
if "tp_ep_group" not in tracker:
10071016
tracker["tp_ep_group"] = tp_ep_group
10081017
tracker["dp_group"] = dp_group
@@ -1050,7 +1059,6 @@ def get_overload_factors_for_logging() -> dict:
10501059
# Stack fwd_bwd tensors (already has fwd positive, bwd negative)
10511060
fwd_bwd_tensors = tracker.get("fwd_bwd", [])
10521061
fwd_bwd_tensors_stacked = torch.stack(fwd_bwd_tensors, dim=0) if fwd_bwd_tensors else None
1053-
10541062
# All-reduce across tp_ep_group, cumsum, and find max
10551063
max_cum_overload_factor = None
10561064
if fwd_bwd_tensors_stacked is not None:
@@ -1120,12 +1128,7 @@ def get_overload_factors_for_logging() -> dict:
11201128
def clear_overload_factor_tracker():
11211129
"""Clear the overload factor tracker."""
11221130
tracker = get_overload_factor_tracker()
1123-
if "fwd" in tracker:
1124-
tracker["fwd"].clear()
1125-
if "fwd_bwd" in tracker:
1126-
tracker["fwd_bwd"].clear()
1127-
tracker.pop("tp_ep_group", None)
1128-
tracker.pop("dp_group", None)
1131+
tracker["to_clear"] = True
11291132

11301133

11311134
def reduce_aux_losses_tracker_across_ranks(

megatron/core/transformer/moe/router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No
667667
logits, self.config.moe_router_force_biased, self.layer_number
668668
)
669669

670-
probs, routing_map = self.routing(logits, padding_mask=padding_mask)
670+
probs, routing_map = self.routing(logits)
671671
# Log overload factor if enabled
672672
if self.config.log_overload_factor:
673673
# Compute num_local_experts from config and EP size

0 commit comments

Comments
 (0)