 import logging
 import warnings
 from dataclasses import astuple
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch.optim import SGD as CPUSGD

 from megatron.core import parallel_state
 from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
+from megatron.core.optimizer_param_scheduler import (
+    ParamGroupOverride,
+    combine_param_group_overrides,
+    param_group_override_to_tuple,
+)
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name

     MegatronOptimizer,
     param_group_identifier_keys,
 )
-from .optimizer_config import AdamOptimizerConfig, OptimizerConfig, ParamKey, SGDOptimizerConfig
+from .optimizer_config import (
+    AdamOptimizerConfig,
+    OptimizerConfig,
+    ParamKey,
+    ParamPredicate,
+    SGDOptimizerConfig,
+)

 logger = logging.getLogger(__name__)


-def _matches(param: torch.nn.Parameter, param_name: str, param_key: ParamKey) -> bool:
-    """Returns true if passed-in parameter (with name) matches `param_key`.
+def get_standard_config_overrides(
+    decoupled_lr: float | None = None, decoupled_min_lr: float | None = None
+) -> Dict[ParamKey, ParamGroupOverride]:
+    """Get standard config overrides handling decoupled LR and common weight-decay skips.

     Args:
-        param (torch.nn.Parameter): Handle to parameter object.
-        param_name (str): Name of parameter in underlying PyTorch module.
-        param_key (ParamKey): ParamKey object.
+        decoupled_lr (float | None): decoupled learning rate.
+        decoupled_min_lr (float | None): decoupled minimum learning rate.

     Returns:
-        bool: True if parameter matches passed-in param_key.
+        Dict[ParamKey, ParamGroupOverride]: standard config overrides.
     """
+    config_overrides: Dict[ParamKey, ParamGroupOverride] = {}
+    if decoupled_lr is not None:
+        decoupled_lr_config: ParamGroupOverride = {"max_lr": decoupled_lr}
+        decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
+        if decoupled_min_lr is not None:
+            decoupled_lr_config["min_lr"] = decoupled_min_lr
+        config_overrides[decoupled_param_key] = decoupled_lr_config
+
+    # Next, construct the standard param group overrides: no weight decay on bias parameters
+    # and on any length-1 parameters.
+    param_length_1_match = ParamPredicate(
+        name="param_len_1", fn=lambda param: len(param.shape) == 1
+    )
+    param_wd_mult_key = ParamKey(name="*.bias", predicate=param_length_1_match)
+    config_overrides[param_wd_mult_key] = ParamGroupOverride(wd_mult=0.0)

-    # Check if name matches.
-    if isinstance(param_key.name, str):
-        target_names = [param_key.name]
-    else:
-        target_names = list(param_key.name)
-    for target_name in target_names:
-        if param_name in target_name:
-            return True
-
-    # Check if attribute matches.
-    if isinstance(param_key.attr, str):
-        target_attrs = [param_key.attr]
-    else:
-        target_attrs = list(param_key.attr)
-    for target_attr in target_attrs:
-        if getattr(param, target_attr, False):
-            return True
-
-    return False
+    return config_overrides


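As a usage sketch (not part of this diff), the standard overrides can be extended with
custom entries before being passed to get_megatron_optimizer; the expert-weight name
pattern below is hypothetical:

    overrides = get_standard_config_overrides(decoupled_lr=1e-4, decoupled_min_lr=1e-5)
    # Hypothetical extra override: scale down the LR for expert weights.
    overrides[ParamKey(name="*.experts.*")] = ParamGroupOverride(lr_mult=0.1)
    optimizer = get_megatron_optimizer(config, model_chunks, config_overrides=overrides)
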
 def _get_param_groups(
     model_chunks: List[MegatronModule],
     config: OptimizerConfig,
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
 ) -> List[Dict]:
     """Create parameter groups for optimizer.

     Creates parameter groups from provided optimizer config object.

+    NOTE: A parameter may match more than one ParamKey. All matching ParamGroupOverrides are
+    merged into a single override for that parameter, and the merged override serves as the
+    parameter's grouping key: parameters that receive the same merged override are mapped
+    into the same parameter group.
+
     Args:
         model_chunks (List[MegatronModule]): model chunks to create parameter
             groups for.
         config (OptimizerConfig): optimizer configuration object.
-        config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
-            specified on a per-layer basis.
+        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]]): optimizer overrides,
+            applied to each parameter whose ParamKey matches. NOTE: the default value of None
+            enables the standard overrides (no weight decay for bias and length-1 parameters);
+            pass an empty dictionary instead if you do not want any overrides at all.
     Returns:
         List of parameter groups.
     """

-    # Map (wd_mult, is_expert_parallel, param_group_hyperparameters_config) to params.
+    # Map (param_override_tuple, is_expert_parallel) to params.
     params_map = {}
-    configs_map = {}
+
+    if config_overrides is None:
+        # TODO: remove this default behavior eventually. It is only needed for backwards
+        # compatibility with the old config-overrides API, where the default config_overrides
+        # meant that bias parameters and length-1 parameters received no weight decay. We
+        # assume that users of decoupled LR already provide config overrides, so they will
+        # adapt to the new API.
+        config_overrides = get_standard_config_overrides()

     for model_chunk in model_chunks:
         for name, param in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue

             uses_default_config = False
-            # Get optimizer config for this parameter.
-            if config_overrides is None:
-                config_for_param = config
-                uses_default_config = True
+            # Get optimizer config overrides for this parameter.
+            param_overrides_list: list[ParamGroupOverride] = []
+            if config_overrides is not None:
+                for param_key, param_override in config_overrides.items():
+                    if param_key.matches(param, name):
+                        param_overrides_list.append(param_override)
+
+            if param_overrides_list:
+                param_override: ParamGroupOverride | None = combine_param_group_overrides(
+                    param_overrides_list
+                )
             else:
-                config_for_param = None
-                for param_key in config_overrides:
-                    if _matches(param, name, param_key):
-                        config_for_param = config_overrides[param_key]
-                        break
-                # Fall back to default config.
-                if config_for_param is None:
-                    config_for_param = config
-                    uses_default_config = True
+                param_override = None

             is_expert_parallel = not getattr(param, 'allreduce', True)

-            # TODO: Make sure there is a way to support old no_weight_decay_func functionality
-            # and default_skip_embedding_weight_decay:
-            # or (default_skip_embedding_weight_decay and "embedding" in name)
-            no_wd = name.endswith(".bias") or len(param.shape) == 1
-            if not no_wd:
-                wd_mult = 1.0
-            else:
-                wd_mult = 0.0
-
-            # Create config_tuple that is hash-able. Remove timers object before
-            # creating config_tuple.
-            config_for_param_copy = copy.deepcopy(config_for_param)
-            config_for_param_copy.timers = None
-            config_tuple = astuple(config_for_param_copy)
-            key = (wd_mult, is_expert_parallel, config_tuple)
+            # Create a hashable param_override_tuple with a consistent ordering of the keys.
+            param_override_tuple: tuple[tuple[str, Any], ...] | None = (
+                param_group_override_to_tuple(param_override)
+            )
+            key = (param_override_tuple, is_expert_parallel)
             if key not in params_map:
                 params_map[key] = []
             params_map[key].append(param)

-            if key in configs_map:
-                assert (config_for_param, uses_default_config) == configs_map[key]
-            else:
-                configs_map[key] = (config_for_param, uses_default_config)
-
     # Distributed checkpoint requires all ranks to have the same param groups,
     # so we need to align the param groups across ranks, otherwise we may have
     # runtime error when loading the checkpoint or numerical error when resuming training.
@@ -168,34 +175,47 @@ def _get_param_groups(
     for key in keys:
         if key not in params_key:
             params_key.append(key)
-
+    # Build one param group per key, using the key's param_override_tuple.
     param_groups = []
-    for key in params_key:
-        wd_mult, is_expert_parallel, _ = key
+    # Sort keys, with None param_override_tuples first.
+    for key in sorted(params_key, key=lambda x: (x[0] is not None, x[0])):
+        param_override_tuple, is_expert_parallel = key
         params = params_map[key] if key in params_map else []
-        config, uses_default_config = None, True
-        if key not in configs_map:
-            assert params == []
+        if param_override_tuple is None:
+            param_override: ParamGroupOverride = {}
         else:
-            config, uses_default_config = configs_map[key]
-            assert config is not None
+            param_override = dict(param_override_tuple)
+
+        # True if param_override_tuple is None or empty, or if the overrides do not modify
+        # the LR schedule.
+        # NOTE: "default_config" is used for logging the learning rate in training.py, so it
+        # is set to True whenever we do not modify the learning rate:
+        #     if param_group['default_config']:
+        #         learning_rate = param_group['lr']
+        uses_default_lr_schedule: bool = (not bool(param_override_tuple)) or not any(
+            "lr" in k for k in param_override
+        )

         # TODO: Remove "backwards compatible" fields below eventually.
+        default_config: ParamGroupOverride = {
+            'wd_mult': 1.0,
+            'lr_mult': 1.0,
+            'is_decoupled_lr': False,
+            # The following two fields may be important to keep even when we remove the
+            # "backwards compatible" fields above.
+            'max_lr': config.lr,  # user may override this in param_override
+            'min_lr': config.min_lr,  # user may override this in param_override
+        }
+        assert (
+            'params' not in param_override
+        ), "'params' should not be in param_override; it is a protected key"
         param_group = {
             'params': params,
-            'wd_mult': wd_mult,  # For backwards compatibility.
-            'lr_mult': 1.0,  # For backwards compatibility.
             'is_expert_parallel': is_expert_parallel,
-            'is_decoupled_lr': False,  # For backwards compatibility.
-            'default_config': uses_default_config,
+            'default_config': uses_default_lr_schedule,
+            **default_config,
+            # Keep **param_override last so that user overrides win over the defaults above.
+            **param_override,
         }
-
-        # Stick relevant fields into param_group from config object.
-        if config is not None:
-            param_group['max_lr'] = config.lr
-            param_group['min_lr'] = config.min_lr
-            # TODO: Add other relevant arguments (e.g., weight decay, optimizer)
-            # here as well.
         param_groups.append(param_group)

     return param_groups
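
To make the grouping concrete, here is a sketch of the merge step, assuming
combine_param_group_overrides performs a plain dict merge (its implementation is not shown
in this diff):

    # An embedding bias parameter matches both standard overrides produced by
    # get_standard_config_overrides(decoupled_lr=1e-4):
    merged = combine_param_group_overrides([{"max_lr": 1e-4}, {"wd_mult": 0.0}])
    # merged == {"max_lr": 1e-4, "wd_mult": 0.0}; every parameter with this exact merged
    # override (and the same is_expert_parallel flag) lands in the same param group, whose
    # dict then carries the defaults plus these overridden fields.
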
@@ -205,7 +225,7 @@ def _get_param_groups_and_buffers(
     model_chunks: List[MegatronModule],
     model_chunk_offset: int,
     config: OptimizerConfig,
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
     filter_fn: Callable,
     buffer_name: str,
 ) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
@@ -216,8 +236,8 @@ def _get_param_groups_and_buffers(
             groups for.
         model_chunk_offset (int): offset of model_chunks in global model_chunks list.
         config (OptimizerConfig): optimizer configuration object.
-        config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
-            specified on a per-layer basis.
+        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]]): optimizer/scheduler
+            overrides, applied to each parameter whose ParamKey matches.
         lr (float): learning rate.
         min_lr (float): minimum learning rate.
         filter_fn (callable): filtering function for param_groups.
@@ -439,10 +459,37 @@ def init_state_fn(opt, config=None):
     return optimizer


+def check_config_overrides_consistency(
+    config: OptimizerConfig, config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]]
+):
+    """Check that config overrides do not change fields that must match the base config."""
+
+    # TODO: Remove `optimizer` from this eventually (e.g., if we use Muon for some layers and
+    # Adam for other layers). This would need some more refactoring to work, though
+    # (param_groups filtered by optimizer passed into
+    # _get_megatron_optimizer_based_on_param_groups).
+    if config_overrides is not None:
+        fields_to_check_for_consistency = [
+            'overlap_param_gather_with_optimizer_step',
+            'optimizer',
+            'optimizer_cpu_offload',
+        ]
+        for field_name in fields_to_check_for_consistency:
+            base_field = getattr(config, field_name, None)
+            all_config_overrides = list(config_overrides.values())
+            for config_override in all_config_overrides:
+                if field_name in config_override:
+                    field = config_override[field_name]
+                    if field != base_field:
+                        raise ValueError(
+                            f"Field {field_name} should not be overridden in a config override."
+                        )
+    return True
+
+
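A behavior sketch for the consistency check; the OptimizerConfig field values here are
assumptions for illustration:

    config = OptimizerConfig(optimizer="adam", lr=1e-4)
    # Overriding a per-group hyperparameter is fine:
    check_config_overrides_consistency(config, {ParamKey(name="*.bias"): {"wd_mult": 0.0}})
    # Overriding a checked field to a different value raises:
    check_config_overrides_consistency(config, {ParamKey(name="*.bias"): {"optimizer": "sgd"}})
    # -> ValueError: Field optimizer should not be overridden in a config override.
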
 def get_megatron_optimizer(
     config: OptimizerConfig,
     model_chunks: List[MegatronModule],
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]] = None,
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]] = None,
     use_gloo_process_groups: bool = True,
     pg_collection: Optional[ProcessGroupCollection] = None,
     dump_param_to_param_group_map: Optional[str] = None,
@@ -468,19 +515,7 @@ def get_megatron_optimizer(

     log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

-    # TODO: Remove `optimizer` from this eventually (e.g., if we use Muon for some layers and
-    # Adam for other layers). This would need some more refactoring to work though (param_groups
-    # filtered by optimizer passed into _get_megatron_optimizer_based_on_param_groups).
-    fields_to_check_for_consistency = [
-        'overlap_param_gather_with_optimizer_step',
-        'optimizer',
-        'optimizer_cpu_offload',
-    ]
-    for field_name in fields_to_check_for_consistency:
-        field = getattr(config, field_name, None)
-        if config_overrides is not None:
-            all_configs = list(config_overrides.values())
-            assert all([getattr(x, field_name, None) == field for x in all_configs])
+    check_config_overrides_consistency(config, config_overrides)

     # Separate out first model chunk if overlapping param AG with optimizer step.
     if config.overlap_param_gather_with_optimizer_step: