
Commit dc917a6

Merge branch 'dev' into jingqiny/feature-mHC
2 parents fc861ab + f983b21

File tree: 22 files changed, +1162 -741 lines


.github/CODEOWNERS

Lines changed: 0 additions & 2 deletions
@@ -1,7 +1,5 @@
 * @NVIDIA/core-nemo @NVIDIA/core-devtech
 
-megatron/core/transformer/cuda_graphs.py @NVIDIA/core-adlr @NVIDIA/core-nemo @NVIDIA/cuda-graphs
-
 .gitlab/ @NVIDIA/ci
 .github/ @NVIDIA/ci
 .gitlab-ci.yml @NVIDIA/ci

megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py

Lines changed: 5 additions & 0 deletions
@@ -175,6 +175,11 @@ def validate_uneven_dtensor(dtensor: DTensor) -> None:
     )
 
     # Check that all boundaries (start and end) are touched.
+    # Skip under fake process group — all_reduce is a no-op so only rank 0's
+    # boundaries are visible, which makes the end-boundary check always fail.
+    if torch.distributed.is_initialized() and torch.distributed.get_backend() == 'fake':
+        return
+
     boundary_checks = torch.tensor(
         [
             [offset == 0, offset + size == dtensor.shape[dim]]
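Note on the guard added above: PyTorch's fake process group implements collectives as no-ops, so a boundary check that relies on all_reduce to merge per-rank information can only ever see the local rank's shards. A minimal sketch of that behavior, assuming PyTorch's test-only fake backend helpers (FakeStore lives under torch.testing._internal and is not a stable API):

import torch
import torch.distributed as dist

# Test-only helper; internal API whose location may change between PyTorch releases.
from torch.testing._internal.distributed.fake_pg import FakeStore

# Pretend to be rank 0 of an 8-rank job without launching any other processes.
dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())
assert dist.get_backend() == "fake"  # the condition checked in validate_uneven_dtensor

# all_reduce does not communicate under the fake backend: the tensor keeps only
# this rank's contribution, so an "every boundary was touched" check would fail.
flags = torch.tensor([1.0, 0.0, 0.0, 0.0])
dist.all_reduce(flags)
print(flags)  # still tensor([1., 0., 0., 0.])

dist.destroy_process_group()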

megatron/core/optimizer/__init__.py

Lines changed: 179 additions & 13 deletions
@@ -2,6 +2,7 @@
 import copy
 import logging
 import warnings
+from collections import defaultdict
 from dataclasses import astuple
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -47,14 +48,22 @@
 from ..transformer.module import MegatronModule
 from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank
 from .distrib_optimizer import DistributedOptimizer
+from .emerging_optimizers import (
+    _EMERGING_OPTIMIZERS,
+    HAVE_EMERGING_OPTIMIZERS,
+    _create_emerging_optimizer,
+)
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
+from .layer_wise_optimizer import LayerWiseDistributedOptimizer
 from .optimizer import (
     ChainedOptimizer,
     Float16OptimizerWithFloat16Params,
     FP32Optimizer,
     MegatronOptimizer,
     param_group_identifier_keys,
 )
+
+# Subclass aliases kept for backward compatibility; all are OptimizerConfig.
 from .optimizer_config import (
     AdamOptimizerConfig,
     OptimizerConfig,
@@ -134,14 +143,6 @@ def _get_param_groups(
     # Map (pg_overrides, is_expert_parallel) to params.
     params_map = {}
 
-    if config_overrides is None:
-        # TODO remove this default behavior eventually.
-        # This is only needed for backwards compatibility with the old config overrides API where
-        # the config_overrides argument by default lead to bias parameters and length 1 parameters.
-        # We assume that users of decoupled LR already provide config overrides so will adapt
-        # to the new API.
-        config_overrides = get_standard_config_overrides(config=config)
-
     for model_chunk in model_chunks:
         for name, param in model_chunk.named_parameters():
             if not param.requires_grad:
@@ -276,7 +277,8 @@ def _get_megatron_optimizer_based_on_param_groups(
     intra_dist_opt_group: Optional[torch.distributed.ProcessGroup] = None,
     distributed_optimizer_instance_id: Optional[int] = 0,
     pg_collection: Optional[ProcessGroupCollection] = None,
-) -> MegatronOptimizer:
+    skip_megatron_wrapping: bool = False,
+) -> Union[MegatronOptimizer, Tuple[Optional[torch.optim.Optimizer], Optional[Callable]]]:
     """Get Megatron optimizer based on parameter groups.
 
     Args:
@@ -292,12 +294,24 @@ def _get_megatron_optimizer_based_on_param_groups(
             optimizer. Defaults to None.
         distributed_optimizer_instance_id (int, optional): Distributed optimizer instance. Defaults
             0.
+        skip_megatron_wrapping (bool): if True, return a
+            ``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer
+            without any Megatron wrapping. Useful when the caller
+            (e.g. LayerWiseDistributedOptimizer) performs its own wrapping.
 
     Returns:
-        Instance of MegatronOptimizer.
+        Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when
+        *skip_megatron_wrapping=True*.
     """
-    # TODO: Logic needs to be updated to handle different optimizer types (i.e., param_groups
-    # passed into this function need to correspond to the same optimizer).
+    # All param_groups passed here must belong to the same optimizer type (adam / sgd).
+    # Callers are responsible for splitting by optimizer type before calling this function.
+
+    if skip_megatron_wrapping and config.use_precision_aware_optimizer:
+        raise ValueError(
+            "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
+        )
+    if skip_megatron_wrapping and config.optimizer_cpu_offload:
+        raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")
 
     # When freezing sub-models we may have no trainable parameters on a rank and
     # hence an empty param_groups. However, we still need to create an optimizer
@@ -412,6 +426,9 @@ def init_state_fn(opt, config=None):
     optimizer = None
     init_state_fn = None
 
+    if skip_megatron_wrapping:
+        return optimizer, init_state_fn
+
     # Mixed precision optimizer.
     # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
     #   from the MixedPrecisionOptimizer, which manages any optimizer where
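The skip_megatron_wrapping flag introduced in the two hunks above turns the factory into a plain optimizer builder: it returns the raw torch.optim optimizer together with its init_state_fn hook and leaves wrapping to the caller. A self-contained toy sketch of that calling pattern (ToyWrapper and build_optimizer are illustrative stand-ins, not Megatron classes):

import torch

class ToyWrapper:
    """Stands in for Megatron-side wrappers such as LayerWiseDistributedOptimizer."""

    def __init__(self, optimizer, init_state_fn=None):
        self.optimizer = optimizer
        if init_state_fn is not None:
            init_state_fn(self.optimizer)  # pre-populate optimizer state

    def step(self):
        self.optimizer.step()

def build_optimizer(params, skip_wrapping=False):
    """Toy analogue of _get_megatron_optimizer_based_on_param_groups."""
    optimizer = torch.optim.SGD(params, lr=0.1)

    def init_state_fn(opt):
        # Hook where per-parameter state would be pre-allocated.
        pass

    if skip_wrapping:
        # New behavior: hand back the raw optimizer plus the hook; the caller wraps it.
        return optimizer, init_state_fn
    return ToyWrapper(optimizer, init_state_fn)

model = torch.nn.Linear(4, 2)
raw_opt, init_fn = build_optimizer(model.parameters(), skip_wrapping=True)
outer = ToyWrapper(raw_opt, init_fn)  # caller-side wrapping, as the layer-wise optimizer does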
@@ -502,6 +519,137 @@ def check_config_overrides_consistency(
     return True
 
 
+def _get_megatron_emerging_optimizer(
+    config: OptimizerConfig,
+    model_chunks: List[MegatronModule],
+    config_overrides: Optional[Dict[ParamKey, Any]] = None,
+    pg_collection: Optional[ProcessGroupCollection] = None,
+) -> MegatronOptimizer:
+    """Build an emerging optimizer (e.g. Muon) for the given model chunks.
+
+    Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a
+    config_override, the same mechanism used for weight-decay and learning-rate overrides.
+    Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they
+    go through the exact same code path as the standard optimizer factory.
+
+    When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers
+    are wrapped with :class:`LayerWiseDistributedOptimizer`.
+    """
+    eopt_name = config.optimizer
+    use_layer_wise = config.use_layer_wise_distributed_optimizer
+
+    # Handle legacy "dist_*" optimizer names (e.g. "dist_muon" → "muon" + layer-wise).
+    if eopt_name.startswith('dist_'):
+        bare_name = eopt_name[len('dist_') :]
+        warnings.warn(
+            f"optimizer='{eopt_name}' is deprecated. "
+            f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        eopt_name = bare_name
+        use_layer_wise = True
+
+    if not HAVE_EMERGING_OPTIMIZERS:
+        raise ImportError(
+            f"emerging-optimizers package is required for optimizer='{eopt_name}'. "
+            "Install it with: pip install emerging-optimizers"
+        )
+    if eopt_name not in _EMERGING_OPTIMIZERS:
+        raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
+    if config.fp16:
+        raise ValueError('emerging optimizer with fp16 is not supported.')
+
+    if pg_collection is None:
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+
+    log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')
+
+    # Tag parameters with optimizer-specific attributes (expert_tp, is_qkv).
+    for model_chunk in model_chunks:
+        for name, param in model_chunk.named_parameters():
+            if not param.requires_grad:
+                continue
+            if 'experts' in name and 'shared' not in name:
+                param.expert_tp = True
+            # TODO(deyuf): support MLA
+            if 'linear_qkv.weight' in name and len(param.shape) == 2:
+                param.is_qkv = True
+
+    # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
+    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)
+
+    # Build param groups and bucket by (optimizer_name, is_expert_parallel).
+    # Layer-wise distributed optimizer handles expert params internally so we skip that split.
+    all_param_groups = _get_param_groups(model_chunks, config, config_overrides)
+    grouped_param_groups = defaultdict(list)
+    for group in all_param_groups:
+        opt_name = group.get('optimizer', eopt_name)
+        is_expert = group['is_expert_parallel'] and not use_layer_wise
+        grouped_param_groups[(opt_name, is_expert)].append(group)
+
+    # Build an optimizer for each (optimizer_name, is_expert) bucket and combine.
+    results = []
+    for (opt_name, is_expert), groups in grouped_param_groups.items():
+        if not groups:
+            continue
+
+        model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp
+
+        if opt_name in _EMERGING_OPTIMIZERS:
+            optimizer, init_state_fn = _create_emerging_optimizer(
+                config, groups, eopt_name, model_chunks, pg_collection
+            )
+            if use_layer_wise:
+                result = (optimizer, init_state_fn)
+            else:
+                if config.bf16:
+                    optimizer = Float16OptimizerWithFloat16Params(
+                        optimizer, config, None, init_state_fn
+                    )
+                else:
+                    optimizer = FP32Optimizer(optimizer, config, init_state_fn)
+                setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
+                if pg_collection is None or not hasattr(pg_collection, 'tp'):
+                    tp_group = parallel_state.get_tensor_model_parallel_group()
+                else:
+                    tp_group = pg_collection.tp
+                setattr(optimizer, 'tp_group', tp_group)
+                result = optimizer
+        else:
+            fallback_config = copy.copy(config)
+            fallback_config.optimizer = opt_name
+            fallback_config.use_distributed_optimizer = False
+            result = _get_megatron_optimizer_based_on_param_groups(
+                config=fallback_config,
+                model_chunks=model_chunks,
+                param_groups=groups,
+                model_parallel_group=model_parallel_group,
+                pg_collection=pg_collection,
+                skip_megatron_wrapping=use_layer_wise,
+            )
+            # TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers
+            # share the same config. Revisit this design now that emerging
+            # optimizers mix different optimizer types (e.g. Muon + Adam).
+            # For now, reset to the top-level config so the assertion holds.
+            if not use_layer_wise and hasattr(result, 'config'):
+                result.config = config
+        results.append(result)
+
+    if use_layer_wise:
+        base_optimizers, init_fns = (), ()
+        if results:
+            base_optimizers, init_fns = zip(*results)
+        log_single_rank(
+            logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}'
+        )
+        return LayerWiseDistributedOptimizer(
+            list(base_optimizers), config, pg_collection, init_state_fn_list=list(init_fns)
+        )
+
+    return ChainedOptimizer(results)
+
+
 def get_megatron_optimizer(
     config: OptimizerConfig,
     model_chunks: List[MegatronModule],
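The bucketing in the function above is the core of the mixed Muon/Adam setup: each param group names its optimizer, groups are keyed by (optimizer name, expert-parallel flag), and every bucket becomes one sub-optimizer that is later chained or wrapped layer-wise. A standalone sketch of just that grouping step, using toy param-group dicts in place of the output of _get_param_groups:

from collections import defaultdict

# Toy stand-ins for the param-group dicts produced by _get_param_groups.
all_param_groups = [
    {"params": ["decoder.linear_qkv.weight"], "optimizer": "muon", "is_expert_parallel": False},
    {"params": ["decoder.layernorm.bias"], "optimizer": "adam", "is_expert_parallel": False},
    {"params": ["experts.linear_fc1.weight"], "optimizer": "muon", "is_expert_parallel": True},
]

eopt_name = "muon"       # config.optimizer
use_layer_wise = False   # config.use_layer_wise_distributed_optimizer

grouped_param_groups = defaultdict(list)
for group in all_param_groups:
    opt_name = group.get("optimizer", eopt_name)  # fall back to the emerging optimizer
    is_expert = group["is_expert_parallel"] and not use_layer_wise
    grouped_param_groups[(opt_name, is_expert)].append(group)

# Three buckets, hence three sub-optimizers combined by ChainedOptimizer:
#   ('muon', False), ('adam', False), ('muon', True)
for key, groups in grouped_param_groups.items():
    print(key, [g["params"] for g in groups])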
@@ -512,7 +660,10 @@ def get_megatron_optimizer(
 ) -> MegatronOptimizer:
     """Retrieve the Megatron optimizer for model chunks.
 
+    Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon).
     We use separate optimizers for expert parameters and non-expert parameters.
+    For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``,
+    the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`.
 
     Args:
         config (OptimizerConfig): optimizer configuration object.
@@ -529,10 +680,25 @@ def get_megatron_optimizer(
         Instance of MegatronOptimizer.
     """
 
-    log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
+    # None → apply standard defaults. To extend defaults with custom overrides,
+    # start from get_standard_config_overrides(config) and merge yours in.
+    if config_overrides is None:
+        config_overrides = get_standard_config_overrides(config)
 
     check_config_overrides_consistency(config, config_overrides)
 
+    # TODO: the standard and emerging optimizer paths handle pg_collection differently;
+    # unify them so both use a single pg_collection-based flow.
+    if config.optimizer not in ('adam', 'sgd'):
+        return _get_megatron_emerging_optimizer(
+            config=config,
+            model_chunks=model_chunks,
+            config_overrides=config_overrides,
+            pg_collection=pg_collection,
+        )
+
+    log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
+
     # Separate out first model chunk if overlapping param AG with optimizer step.
     if config.overlap_param_gather_with_optimizer_step:
         all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
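From a user's perspective, the routing added in this file works as follows: any config.optimizer other than 'adam' or 'sgd' is sent to _get_megatron_emerging_optimizer, config_overrides left as None picks up get_standard_config_overrides(config), and layer-wise wrapping is requested through the config flag. A hedged usage sketch based on the fields shown in this diff (values such as lr follow the existing OptimizerConfig, and the model_chunks list comes from the usual model setup, which is not shown here):

from megatron.core.optimizer import (
    OptimizerConfig,
    get_megatron_optimizer,
    get_standard_config_overrides,
)

# 'muon' is not in ('adam', 'sgd'), so get_megatron_optimizer routes to
# _get_megatron_emerging_optimizer. The legacy 'dist_muon' spelling still works
# but emits a DeprecationWarning.
config = OptimizerConfig(
    optimizer="muon",
    lr=3e-4,
    bf16=True,  # fp16 is rejected for emerging optimizers
    use_layer_wise_distributed_optimizer=True,  # wrap with LayerWiseDistributedOptimizer
)

# config_overrides=None would apply the standard defaults; to extend them,
# start from get_standard_config_overrides(config) and merge custom entries.
overrides = get_standard_config_overrides(config)

# model_chunks: list of MegatronModule chunks from model construction (not shown).
optimizer = get_megatron_optimizer(config, model_chunks, config_overrides=overrides)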
