Skip to content

Commit 3307414

Browse files
Merge branch 'dev' into yuzhongw/fix_dsa_spec_dev
2 parents 311b567 + f983b21 commit 3307414

File tree

6 files changed

+469
-180
lines changed

6 files changed

+469
-180
lines changed

megatron/core/transformer/cuda_graphs.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -819,13 +819,12 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True):
819819

820820
is_moe = isinstance(self.base_module, MoETransformerLayer)
821821
if is_moe:
822-
from megatron.core.transformer.moe.moe_utils import get_moe_layer_wise_logging_tracker
822+
from megatron.core.transformer.moe.moe_logging import get_moe_metrics_tracker
823823

824-
tracker = get_moe_layer_wise_logging_tracker()
824+
moe_metrics_tracker = get_moe_metrics_tracker()
825825
cached_aux_losses = {}
826-
for name in tracker:
827-
if "values" in tracker[name]:
828-
cached_aux_losses[name] = torch.clone(tracker[name]["values"])
826+
for name, entry in moe_metrics_tracker.metrics.items():
827+
cached_aux_losses[name] = entry.values.clone()
829828

830829
self.fwd_graph = torch.cuda.CUDAGraph()
831830

@@ -1014,8 +1013,11 @@ def clone_ten(ten):
10141013
param.main_grad.copy_(main_grad_copy)
10151014

10161015
if is_moe:
1017-
for name in tracker:
1018-
tracker[name]["values"].copy_(cached_aux_losses[name])
1016+
for name, cached_values in cached_aux_losses.items():
1017+
assert (
1018+
name in moe_metrics_tracker.metrics
1019+
), "cached metrics must be found in the tracker."
1020+
moe_metrics_tracker.metrics[name].values.copy_(cached_values)
10191021

10201022
def create_bwd_graph(self):
10211023
"""Create a bwd cudagraph for this runner. Should be called inside
@@ -2208,14 +2210,15 @@ def _finish_capturing(self, start_time):
22082210
_set_capture_end()
22092211

22102212
from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors
2211-
from megatron.core.transformer.moe.moe_utils import clear_aux_losses_tracker
22122213

22132214
torch.distributed.barrier()
22142215
for model_chunk in self.model:
22152216
model_chunk.zero_grad_buffer()
22162217
for optimizer in self.optimizers:
22172218
optimizer.zero_grad()
2218-
clear_aux_losses_tracker()
2219+
from megatron.core.transformer.moe.moe_logging import get_moe_metrics_tracker
2220+
2221+
get_moe_metrics_tracker().clear()
22192222
reset_model_temporary_tensors(self.config, self.model)
22202223

22212224
if FREEZE_GC:

0 commit comments

Comments (0)