22import copy
33import logging
44import warnings
5+ from collections import defaultdict
56from dataclasses import astuple
67from typing import Any , Callable , Dict , List , Optional , Tuple , Union
78
4748from ..transformer .module import MegatronModule
4849from ..utils import get_model_config , get_pg_rank , get_pg_size , is_te_min_version , log_single_rank
4950from .distrib_optimizer import DistributedOptimizer
51+ from .emerging_optimizers import (
52+ _EMERGING_OPTIMIZERS ,
53+ HAVE_EMERGING_OPTIMIZERS ,
54+ _create_emerging_optimizer ,
55+ )
5056from .grad_scaler import ConstantGradScaler , DynamicGradScaler
57+ from .layer_wise_optimizer import LayerWiseDistributedOptimizer
5158from .optimizer import (
5259 ChainedOptimizer ,
5360 Float16OptimizerWithFloat16Params ,
5461 FP32Optimizer ,
5562 MegatronOptimizer ,
5663 param_group_identifier_keys ,
5764)
65+
66+ # Subclass aliases kept for backward compatibility; all are OptimizerConfig.
5867from .optimizer_config import (
5968 AdamOptimizerConfig ,
6069 OptimizerConfig ,
@@ -134,14 +143,6 @@ def _get_param_groups(
134143 # Map (pg_overrides, is_expert_parallel) to params.
135144 params_map = {}
136145
137- if config_overrides is None :
138- # TODO remove this default behavior eventually.
139- # This is only needed for backwards compatibility with the old config overrides API where
140- # the config_overrides argument by default lead to bias parameters and length 1 parameters.
141- # We assume that users of decoupled LR already provide config overrides so will adapt
142- # to the new API.
143- config_overrides = get_standard_config_overrides (config = config )
144-
145146 for model_chunk in model_chunks :
146147 for name , param in model_chunk .named_parameters ():
147148 if not param .requires_grad :
@@ -276,7 +277,8 @@ def _get_megatron_optimizer_based_on_param_groups(
276277 intra_dist_opt_group : Optional [torch .distributed .ProcessGroup ] = None ,
277278 distributed_optimizer_instance_id : Optional [int ] = 0 ,
278279 pg_collection : Optional [ProcessGroupCollection ] = None ,
279- ) -> MegatronOptimizer :
280+ skip_megatron_wrapping : bool = False ,
281+ ) -> Union [MegatronOptimizer , Tuple [Optional [torch .optim .Optimizer ], Optional [Callable ]]]:
280282 """Get Megatron optimizer based on parameter groups.
281283
282284 Args:
@@ -292,12 +294,24 @@ def _get_megatron_optimizer_based_on_param_groups(
292294 optimizer. Defaults to None.
293295 distributed_optimizer_instance_id (int, optional): Distributed optimizer instance. Defaults
294296 0.
297+ skip_megatron_wrapping (bool): if True, return a
298+ ``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer
299+ without any Megatron wrapping. Useful when the caller
300+ (e.g. LayerWiseDistributedOptimizer) performs its own wrapping.
295301
296302 Returns:
297- Instance of MegatronOptimizer.
303+ Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when
304+ *skip_megatron_wrapping=True*.
298305 """
299- # TODO: Logic needs to be updated to handle different optimizer types (i.e., param_groups
300- # passed into this function need to correspond to the same optimizer).
306+ # All param_groups passed here must belong to the same optimizer type (adam / sgd).
307+ # Callers are responsible for splitting by optimizer type before calling this function.
308+
309+ if skip_megatron_wrapping and config .use_precision_aware_optimizer :
310+ raise ValueError (
311+ "skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
312+ )
313+ if skip_megatron_wrapping and config .optimizer_cpu_offload :
314+ raise ValueError ("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload." )
301315
302316 # When freezing sub-models we may have no trainable parameters on a rank and
303317 # hence an empty param_groups. However, we still need to create an optimizer
@@ -412,6 +426,9 @@ def init_state_fn(opt, config=None):
412426 optimizer = None
413427 init_state_fn = None
414428
429+ if skip_megatron_wrapping :
430+ return optimizer , init_state_fn
431+
415432 # Mixed precision optimizer.
416433 # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
417434 # from the MixedPrecisionOptimizer, which manages any optimizer where
@@ -502,6 +519,137 @@ def check_config_overrides_consistency(
502519 return True
503520
504521
def _get_megatron_emerging_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    config_overrides: Optional[Dict[ParamKey, Any]] = None,
    pg_collection: Optional[ProcessGroupCollection] = None,
) -> MegatronOptimizer:
    """Build an emerging optimizer (e.g. Muon) for the given model chunks.

    Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a
    config_override, the same mechanism used for weight-decay and learning-rate overrides.
    Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they
    go through the exact same code path as the standard optimizer factory.

    When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers
    are wrapped with :class:`LayerWiseDistributedOptimizer`.

    Args:
        config (OptimizerConfig): optimizer configuration object; ``config.optimizer``
            selects the emerging optimizer. Legacy ``dist_*`` names (e.g. ``dist_muon``)
            are accepted but deprecated.
        model_chunks (List[MegatronModule]): model chunks to build the optimizer for.
        config_overrides (Dict[ParamKey, Any], optional): per-parameter config overrides.
            Defaults to no overrides. The selected optimizer's default param overrides
            are applied on top of (and take precedence over) the provided ones.
        pg_collection (ProcessGroupCollection, optional): process groups for distributed
            communication. Defaults to the MPU process groups.

    Returns:
        Instance of MegatronOptimizer: a :class:`LayerWiseDistributedOptimizer` when
        layer-wise mode is enabled, otherwise a :class:`ChainedOptimizer` over the
        per-(optimizer, expert) sub-optimizers.

    Raises:
        ImportError: if the emerging-optimizers package is not installed.
        ValueError: if the optimizer name is unsupported or ``config.fp16`` is set.
    """
    eopt_name = config.optimizer
    use_layer_wise = config.use_layer_wise_distributed_optimizer

    # Handle legacy "dist_*" optimizer names (e.g. "dist_muon" -> "muon" + layer-wise).
    if eopt_name.startswith('dist_'):
        bare_name = eopt_name[len('dist_'):]
        warnings.warn(
            f"optimizer='{eopt_name}' is deprecated. "
            f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.",
            DeprecationWarning,
            stacklevel=3,
        )
        eopt_name = bare_name
        use_layer_wise = True

    if not HAVE_EMERGING_OPTIMIZERS:
        raise ImportError(
            f"emerging-optimizers package is required for optimizer='{eopt_name}'. "
            "Install it with: pip install emerging-optimizers"
        )
    if eopt_name not in _EMERGING_OPTIMIZERS:
        raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
    if config.fp16:
        raise ValueError('emerging optimizer with fp16 is not supported.')

    if pg_collection is None:
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()

    log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')

    # Tag parameters with optimizer-specific attributes (expert_tp, is_qkv).
    for model_chunk in model_chunks:
        for name, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue
            if 'experts' in name and 'shared' not in name:
                param.expert_tp = True
            # TODO(deyuf): support MLA
            if 'linear_qkv.weight' in name and len(param.shape) == 2:
                param.is_qkv = True

    # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
    # Copy into a fresh dict first: this tolerates the documented config_overrides=None
    # default (calling .update on None would raise AttributeError) and avoids mutating
    # the caller's dict in place.
    config_overrides = dict(config_overrides) if config_overrides is not None else {}
    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)

    # Build param groups and bucket by (optimizer_name, is_expert_parallel).
    # Layer-wise distributed optimizer handles expert params internally so we skip that split.
    all_param_groups = _get_param_groups(model_chunks, config, config_overrides)
    grouped_param_groups = defaultdict(list)
    for group in all_param_groups:
        opt_name = group.get('optimizer', eopt_name)
        is_expert = group['is_expert_parallel'] and not use_layer_wise
        grouped_param_groups[(opt_name, is_expert)].append(group)

    # Build an optimizer for each (optimizer_name, is_expert) bucket and combine.
    results = []
    for (opt_name, is_expert), groups in grouped_param_groups.items():
        if not groups:
            continue

        model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp

        if opt_name in _EMERGING_OPTIMIZERS:
            optimizer, init_state_fn = _create_emerging_optimizer(
                config, groups, eopt_name, model_chunks, pg_collection
            )
            if use_layer_wise:
                # Defer wrapping to LayerWiseDistributedOptimizer below.
                result = (optimizer, init_state_fn)
            else:
                if config.bf16:
                    optimizer = Float16OptimizerWithFloat16Params(
                        optimizer, config, None, init_state_fn
                    )
                else:
                    optimizer = FP32Optimizer(optimizer, config, init_state_fn)
                setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
                # pg_collection is guaranteed non-None here; fall back to the global TP
                # group only when this collection does not expose one.
                if hasattr(pg_collection, 'tp'):
                    tp_group = pg_collection.tp
                else:
                    tp_group = parallel_state.get_tensor_model_parallel_group()
                setattr(optimizer, 'tp_group', tp_group)
                result = optimizer
        else:
            # Non-emerging (adam / sgd) groups reuse the standard factory path.
            fallback_config = copy.copy(config)
            fallback_config.optimizer = opt_name
            fallback_config.use_distributed_optimizer = False
            result = _get_megatron_optimizer_based_on_param_groups(
                config=fallback_config,
                model_chunks=model_chunks,
                param_groups=groups,
                model_parallel_group=model_parallel_group,
                pg_collection=pg_collection,
                skip_megatron_wrapping=use_layer_wise,
            )
            # TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers
            # share the same config. Revisit this design now that emerging
            # optimizers mix different optimizer types (e.g. Muon + Adam).
            # For now, reset to the top-level config so the assertion holds.
            if not use_layer_wise and hasattr(result, 'config'):
                result.config = config
        results.append(result)

    if use_layer_wise:
        # In layer-wise mode every result is an (optimizer, init_state_fn) pair.
        base_optimizers, init_fns = (), ()
        if results:
            base_optimizers, init_fns = zip(*results)
        log_single_rank(
            logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}'
        )
        return LayerWiseDistributedOptimizer(
            list(base_optimizers), config, pg_collection, init_state_fn_list=list(init_fns)
        )

    return ChainedOptimizer(results)
652+
505653def get_megatron_optimizer (
506654 config : OptimizerConfig ,
507655 model_chunks : List [MegatronModule ],
@@ -512,7 +660,10 @@ def get_megatron_optimizer(
512660) -> MegatronOptimizer :
513661 """Retrieve the Megatron optimizer for model chunks.
514662
663+ Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon).
515664 We use separate optimizers for expert parameters and non-expert parameters.
665+ For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``,
666+ the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`.
516667
517668 Args:
518669 config (OptimizerConfig): optimizer configuration object.
@@ -529,10 +680,25 @@ def get_megatron_optimizer(
529680 Instance of MegatronOptimizer.
530681 """
531682
532- log_single_rank (logger , logging .INFO , f'Setting up optimizer with config { config } ' )
683+ # None → apply standard defaults. To extend defaults with custom overrides,
684+ # start from get_standard_config_overrides(config) and merge yours in.
685+ if config_overrides is None :
686+ config_overrides = get_standard_config_overrides (config )
533687
534688 check_config_overrides_consistency (config , config_overrides )
535689
690+ # TODO: the standard and emerging optimizer paths handle pg_collection differently;
691+ # unify them so both use a single pg_collection-based flow.
692+ if config .optimizer not in ('adam' , 'sgd' ):
693+ return _get_megatron_emerging_optimizer (
694+ config = config ,
695+ model_chunks = model_chunks ,
696+ config_overrides = config_overrides ,
697+ pg_collection = pg_collection ,
698+ )
699+
700+ log_single_rank (logger , logging .INFO , f'Setting up optimizer with config { config } ' )
701+
536702 # Separate out first model chunk if overlapping param AG with optimizer step.
537703 if config .overlap_param_gather_with_optimizer_step :
538704 all_dense_model_chunks = [[model_chunks [0 ]], model_chunks [1 :]]
0 commit comments