@@ -819,13 +819,12 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True):
819819
820820 is_moe = isinstance (self .base_module , MoETransformerLayer )
821821 if is_moe :
822- from megatron .core .transformer .moe .moe_utils import get_moe_layer_wise_logging_tracker
822+ from megatron .core .transformer .moe .moe_logging import get_moe_metrics_tracker
823823
824- tracker = get_moe_layer_wise_logging_tracker ()
824+ moe_metrics_tracker = get_moe_metrics_tracker ()
825825 cached_aux_losses = {}
826- for name in tracker :
827- if "values" in tracker [name ]:
828- cached_aux_losses [name ] = torch .clone (tracker [name ]["values" ])
826+ for name , entry in moe_metrics_tracker .metrics .items ():
827+ cached_aux_losses [name ] = entry .values .clone ()
829828
830829 self .fwd_graph = torch .cuda .CUDAGraph ()
831830
@@ -1014,8 +1013,11 @@ def clone_ten(ten):
10141013 param .main_grad .copy_ (main_grad_copy )
10151014
10161015 if is_moe :
1017- for name in tracker :
1018- tracker [name ]["values" ].copy_ (cached_aux_losses [name ])
1016+ for name , cached_values in cached_aux_losses .items ():
1017+ assert (
1018+ name in moe_metrics_tracker .metrics
1019+ ), "cached metrics must be found in the tracker."
1020+ moe_metrics_tracker .metrics [name ].values .copy_ (cached_values )
10191021
10201022 def create_bwd_graph (self ):
10211023 """Create a bwd cudagraph for this runner. Should be called inside
@@ -2208,14 +2210,15 @@ def _finish_capturing(self, start_time):
22082210 _set_capture_end ()
22092211
22102212 from megatron .core .distributed .finalize_model_grads import reset_model_temporary_tensors
2211- from megatron .core .transformer .moe .moe_utils import clear_aux_losses_tracker
22122213
22132214 torch .distributed .barrier ()
22142215 for model_chunk in self .model :
22152216 model_chunk .zero_grad_buffer ()
22162217 for optimizer in self .optimizers :
22172218 optimizer .zero_grad ()
2218- clear_aux_losses_tracker ()
2219+ from megatron .core .transformer .moe .moe_logging import get_moe_metrics_tracker
2220+
2221+ get_moe_metrics_tracker ().clear ()
22192222 reset_model_temporary_tensors (self .config , self .model )
22202223
22212224 if FREEZE_GC :
0 commit comments