@@ -1003,6 +1003,15 @@ def save_overload_factor_to_tracker(
     """
     # Set comm groups in tracker (outside autograd function)
     tracker = get_overload_factor_tracker()
+    if "to_clear" in tracker and tracker["to_clear"]:
+        if "fwd" in tracker:
+            tracker["fwd"].clear()
+        if "fwd_bwd" in tracker:
+            tracker["fwd_bwd"].clear()
+        tracker.pop("tp_ep_group", None)
+        tracker.pop("dp_group", None)
+        tracker.pop("to_clear", None)
+
     if "tp_ep_group" not in tracker:
         tracker["tp_ep_group"] = tp_ep_group
         tracker["dp_group"] = dp_group
@@ -1050,7 +1059,6 @@ def get_overload_factors_for_logging() -> dict:
10501059 # Stack fwd_bwd tensors (already has fwd positive, bwd negative)
10511060 fwd_bwd_tensors = tracker .get ("fwd_bwd" , [])
10521061 fwd_bwd_tensors_stacked = torch .stack (fwd_bwd_tensors , dim = 0 ) if fwd_bwd_tensors else None
1053-
10541062 # All-reduce across tp_ep_group, cumsum, and find max
10551063 max_cum_overload_factor = None
10561064 if fwd_bwd_tensors_stacked is not None :
@@ -1120,12 +1128,7 @@ def get_overload_factors_for_logging() -> dict:
11201128def clear_overload_factor_tracker ():
11211129 """Clear the overload factor tracker."""
11221130 tracker = get_overload_factor_tracker ()
1123- if "fwd" in tracker :
1124- tracker ["fwd" ].clear ()
1125- if "fwd_bwd" in tracker :
1126- tracker ["fwd_bwd" ].clear ()
1127- tracker .pop ("tp_ep_group" , None )
1128- tracker .pop ("dp_group" , None )
1131+ tracker ["to_clear" ] = True
11291132
11301133
11311134def reduce_aux_losses_tracker_across_ranks (
0 commit comments