
Commit 11de188

jstjohn authored and maanug-nv committed
Implementation of a more flexible optimizer/scheduler override system (NVIDIA#2723)
Signed-off-by: John St John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent e395f27 commit 11de188


8 files changed, +471 -121 lines changed


megatron/core/optimizer/__init__.py

Lines changed: 136 additions & 101 deletions
@@ -3,7 +3,7 @@
 import logging
 import warnings
 from dataclasses import astuple
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch.optim import SGD as CPUSGD
@@ -35,6 +35,11 @@

 from megatron.core import parallel_state
 from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
+from megatron.core.optimizer_param_scheduler import (
+    ParamGroupOverride,
+    combine_param_group_overrides,
+    param_group_override_to_tuple,
+)
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name

@@ -50,114 +55,116 @@
     MegatronOptimizer,
     param_group_identifier_keys,
 )
-from .optimizer_config import AdamOptimizerConfig, OptimizerConfig, ParamKey, SGDOptimizerConfig
+from .optimizer_config import (
+    AdamOptimizerConfig,
+    OptimizerConfig,
+    ParamKey,
+    ParamPredicate,
+    SGDOptimizerConfig,
+)

 logger = logging.getLogger(__name__)


-def _matches(param: torch.nn.Parameter, param_name: str, param_key: ParamKey) -> bool:
-    """Returns true if passed-in parameter (with name) matches `param_key`.
+def get_standard_config_overrides(
+    decoupled_lr: float | None = None, decoupled_min_lr: float | None = None
+) -> Dict[ParamKey, ParamGroupOverride]:
+    """Get standard config overrides for the optimizer, handling decoupled LR and common wd skips.

     Args:
-        param (torch.nn.Parameter): Handle to parameter object.
-        param_name (str): Name of parameter in underlying PyTorch module.
-        param_key (ParamKey): ParamKey object.
+        decoupled_lr (float | None): decoupled learning rate.
+        decoupled_min_lr (float | None): decoupled minimum learning rate.

     Returns:
-        bool: True if parameter matches passed-in param_key.
+        Dict[ParamKey, ParamGroupOverride]: standard config overrides.
     """
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]] = {}
+    if decoupled_lr is not None:
+        decoupled_lr_config: ParamGroupOverride = {"max_lr": decoupled_lr}
+        decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
+        if decoupled_min_lr is not None:
+            decoupled_lr_config["min_lr"] = decoupled_min_lr
+        config_overrides[decoupled_param_key] = decoupled_lr_config
+
+    # Next construct the standard param group overrides for no weight decay on bias parameters
+    # as well as any length 1 parameters.
+    param_length_1_match = ParamPredicate(
+        name="param_len_1", fn=lambda param: len(param.shape) == 1
+    )
+    param_wd_mult_key = ParamKey(name="*.bias", predicate=param_length_1_match)
+    config_overrides[param_wd_mult_key] = ParamGroupOverride(wd_mult=0.0)

-    # Check if name matches.
-    if isinstance(param_key.name, str):
-        target_names = [param_key.name]
-    else:
-        target_names = list(param_key.name)
-    for target_name in target_names:
-        if param_name in target_name:
-            return True
-
-    # Check if attribute matches.
-    if isinstance(param_key.attr, str):
-        target_attrs = [param_key.attr]
-    else:
-        target_attrs = list(param_key.attr)
-    for target_attr in target_attrs:
-        if getattr(param, target_attr, False):
-            return True
-
-    return False
+    return config_overrides


 def _get_param_groups(
     model_chunks: List[MegatronModule],
     config: OptimizerConfig,
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
 ) -> List[Dict]:
     """Create parameter groups for optimizer.

     Creates parameter groups from provided optimizer config object.

+    NOTE There can be more than one match between a ParamKey and a parameter.
+    What we do is merge all of the matching ParamKey overrides into a single ParamGroupOverride
+    for that parameter and use that as the key for that parameter. Any parameters that get
+    the same set of merged overrides will be mapped into the same parameter group.
+
     Args:
         model_chunks (List[MegatronModule]): model chunks to create parameter
             groups for.
         config (OptimizerConfig): optimizer configuration object.
-        config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
-            specified on a per-layer basis.
+        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]): optimizer overrides,
+            specified on a per-layer basis. NOTE: if you want to skip applying weight decay on bias
+            and length 1 parameters, and also do not want to do any other overrides, set this to an
+            empty dictionary rather than the default value of None.
     Returns:
         List of parameter groups.
     """

-    # Map (wd_mult, is_expert_parallel, param_group_hyperparameters_config) to params.
+    # Map (pg_overrides, is_expert_parallel) to params.
     params_map = {}
-    configs_map = {}
+
+    if config_overrides is None:
+        # TODO remove this default behavior eventually.
+        # This is only needed for backwards compatibility with the old config overrides API where
+        # the config_overrides argument by default lead to bias parameters and length 1 parameters.
+        # We assume that users of decoupled LR already provide config overrides so will adapt
+        # to the new API.
+        config_overrides = get_standard_config_overrides()

     for model_chunk in model_chunks:
         for name, param in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue

             uses_default_config = False
-            # Get optimizer config for this parameter.
-            if config_overrides is None:
-                config_for_param = config
-                uses_default_config = True
+            # Get optimizer config overrides for this parameter.
+            param_overrides_list: list[ParamGroupOverride] = []
+            if config_overrides is not None:
+                for param_key, param_override in config_overrides.items():
+                    if param_key.matches(param, name):
+                        param_overrides_list.append(param_override)
+
+            if param_overrides_list:
+                param_override: ParamGroupOverride | None = combine_param_group_overrides(
+                    param_overrides_list
+                )
             else:
-                config_for_param = None
-                for param_key in config_overrides:
-                    if _matches(param, name, param_key):
-                        config_for_param = config_overrides[param_key]
-                        break
-                # Fall back to default config.
-                if config_for_param is None:
-                    config_for_param = config
-                    uses_default_config = True
+                param_override = None

             is_expert_parallel = not getattr(param, 'allreduce', True)

-            # TODO: Make sure there is a way to support old no_weight_decay_func functionality
-            # and default_skip_embedding_weight_decay:
-            # or (default_skip_embedding_weight_decay and "embedding" in name)
-            no_wd = name.endswith(".bias") or len(param.shape) == 1
-            if not no_wd:
-                wd_mult = 1.0
-            else:
-                wd_mult = 0.0
-
-            # Create config_tuple that is hash-able. Remove timers object before
-            # creating config_tuple.
-            config_for_param_copy = copy.deepcopy(config_for_param)
-            config_for_param_copy.timers = None
-            config_tuple = astuple(config_for_param_copy)
-            key = (wd_mult, is_expert_parallel, config_tuple)
+            # Create config_tuple that is hash-able, and has a consistent ordering of the keys.
+            param_override_tuple: tuple[tuple[str, Any], ...] | None = (
+                param_group_override_to_tuple(param_override)
+            )
+            key = (param_override_tuple, is_expert_parallel)
             if key not in params_map:
                 params_map[key] = []
             params_map[key].append(param)

-            if key in configs_map:
-                assert (config_for_param, uses_default_config) == configs_map[key]
-            else:
-                configs_map[key] = (config_for_param, uses_default_config)
-
     # Distributed checkpoint requires all ranks to have the same param groups,
     # so we need to align the param groups across ranks, otherwise we may have
     # runtime error when loading the checkpoint or numerical error when resuming training.
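
Taken together, get_standard_config_overrides and the ParamKey matching above suggest a usage pattern along the following lines. This is a hedged sketch assembled only from the interfaces visible in this diff (ParamKey, ParamPredicate, get_megatron_optimizer, and plain-dict ParamGroupOverride values); the hyperparameter values and the model_chunks placeholder are illustrative, and any OptimizerConfig constructor arguments beyond lr, min_lr, and optimizer are assumptions. The combine_param_group_overrides and param_group_override_to_tuple helpers imported from megatron.core.optimizer_param_scheduler presumably merge multiple matching overrides into one dict and turn it into a hashable grouping key, but their definitions are not part of this file's diff.

    from megatron.core.optimizer import get_megatron_optimizer
    from megatron.core.optimizer.optimizer_config import OptimizerConfig, ParamKey, ParamPredicate

    config = OptimizerConfig(optimizer="adam", lr=3e-4, min_lr=3e-5)

    config_overrides = {
        # Decoupled LR for embedding/output parameters (same attribute used by
        # get_standard_config_overrides above).
        ParamKey(attr="is_embedding_or_output_parameter"): {"max_lr": 1e-4, "min_lr": 1e-5},
        # Skip weight decay on biases and other 1-D parameters (e.g. layernorm weights).
        ParamKey(
            name="*.bias",
            predicate=ParamPredicate(name="param_len_1", fn=lambda p: len(p.shape) == 1),
        ): {"wd_mult": 0.0},
    }

    # model_chunks: the List[MegatronModule] built elsewhere by the training setup (placeholder).
    optimizer = get_megatron_optimizer(config, model_chunks, config_overrides=config_overrides)

Per the NOTE in the docstring, a parameter matched by several ParamKeys gets the union of their overrides, and parameters that end up with the same merged override land in the same parameter group.
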
@@ -168,34 +175,47 @@ def _get_param_groups(
     for key in keys:
         if key not in params_key:
             params_key.append(key)
-
+    # Need to pick one of the param_override_tuples to use for the param group.
     param_groups = []
-    for key in params_key:
-        wd_mult, is_expert_parallel, _ = key
+    # Sort keys, None first.
+    for key in sorted(params_key, key=lambda x: (x[0] is not None, x[0])):
+        param_override_tuple, is_expert_parallel = key
         params = params_map[key] if key in params_map else []
-        config, uses_default_config = None, True
-        if key not in configs_map:
-            assert params == []
+        if param_override_tuple is None:
+            param_override: ParamGroupOverride = {}
         else:
-            config, uses_default_config = configs_map[key]
-            assert config is not None
+            param_override: ParamGroupOverride = {k: v for (k, v) in param_override_tuple}
+
+        # False if param_group_override is None or empty tuple or if we do not modify the
+        # LR schedule.
+        # NOTE: "default_config" is used for logging the learning rate in training.py.
+        # so set to True if we do not modify the learning rate.
+        # if param_group['default_config']:
+        #     learning_rate = param_group['lr']
+        uses_default_lr_schedule: bool = (not bool(param_override_tuple)) or not any(
+            ["lr" in k for k in param_override]
+        )

         # TODO: Remove "backwards compatible" fields below eventually.
+        default_config: ParamGroupOverride = {
+            'wd_mult': 1.0,
+            'lr_mult': 1.0,
+            'is_decoupled_lr': False,
+            # The following two fields may be important to keep even when we remove the
+            # above "backwards compatible" fields.
+            "max_lr": config.lr,  # user may override this in param_override
+            "min_lr": config.min_lr,  # user may override this in param_override
+        }
+        assert (
+            "params" not in param_override
+        ), "'params' should not be in param_override, this is a protected key"
         param_group = {
             'params': params,
-            'wd_mult': wd_mult,  # For backwards compatibility.
-            'lr_mult': 1.0,  # For backwards compatibility.
             'is_expert_parallel': is_expert_parallel,
-            'is_decoupled_lr': False,  # For backwards compatibility.
-            'default_config': uses_default_config,
+            'default_config': uses_default_lr_schedule,
+            **default_config,
+            **param_override,  # keep **param_override last so that users can override other fields.
         }
-
-        # Stick relevant fields into param_group from config object.
-        if config is not None:
-            param_group['max_lr'] = config.lr
-            param_group['min_lr'] = config.min_lr
-            # TODO: Add other relevant arguments (e.g., weight decay, optimizer)
-            # here as well.
         param_groups.append(param_group)

     return param_groups
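
The param_group assembly above leans on dict-unpacking order: the backwards-compatible defaults plus config.lr/config.min_lr are written first, and the merged per-parameter override is expanded last so it wins any conflicts. A small self-contained illustration of that precedence, with made-up values standing in for the config and the merged override:

    # Stand-ins for config.lr / config.min_lr and for one parameter's merged override.
    config_lr, config_min_lr = 3e-4, 3e-5
    param_override = {"max_lr": 1e-4, "wd_mult": 0.0}

    default_config = {
        "wd_mult": 1.0,
        "lr_mult": 1.0,
        "is_decoupled_lr": False,
        "max_lr": config_lr,
        "min_lr": config_min_lr,
    }
    param_group = {
        "params": [],
        "is_expert_parallel": False,
        "default_config": False,  # the override touches an "lr" key, so not the default LR schedule
        **default_config,
        **param_override,  # expanded last, so it overrides the defaults above
    }

    assert param_group["max_lr"] == 1e-4 and param_group["wd_mult"] == 0.0
    assert param_group["min_lr"] == config_min_lr  # untouched defaults pass through unchanged
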
@@ -205,7 +225,7 @@ def _get_param_groups_and_buffers(
     model_chunks: List[MegatronModule],
     model_chunk_offset: int,
     config: OptimizerConfig,
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
     filter_fn: Callable,
     buffer_name: str,
 ) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
@@ -216,8 +236,8 @@ def _get_param_groups_and_buffers(
             groups for.
         model_chunk_offset (int): offset of model_chunks in global model_chunks list.
         config (OptimizerConfig): optimizer configuration object.
-        config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
-            specified on a per-layer basis.
+        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]): optimizer/scheduler
+            overrides, specified on the basis of ParamKey matches with each parameter.
         lr (float): learning rate.
         min_lr (float): minimum learning rate.
         filter_fn (callable): filtering function for param_groups.
@@ -439,10 +459,37 @@ def init_state_fn(opt, config=None):
     return optimizer


+def check_config_overrides_consistency(
+    config: OptimizerConfig, config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]]
+):
+    """Check if the config overrides are consistent with the config."""
+
+    # TODO: Remove `optimizer` from this eventually (e.g., if we use Muon for some layers and
+    # Adam for other layers). This would need some more refactoring to work though (param_groups
+    # filtered by optimizer passed into _get_megatron_optimizer_based_on_param_groups).
+    if config_overrides is not None:
+        fields_to_check_for_consistency = [
+            'overlap_param_gather_with_optimizer_step',
+            'optimizer',
+            'optimizer_cpu_offload',
+        ]
+        for field_name in fields_to_check_for_consistency:
+            base_field = getattr(config, field_name, None)
+            all_config_overrides = list(config_overrides.values())
+            for config_override in all_config_overrides:
+                if field_name in config_override:
+                    field = config_override[field_name]
+                    if field != base_field:
+                        raise ValueError(
+                            f"Field {field_name} should not be overriden in a config override."
+                        )
+    return True
+
+
 def get_megatron_optimizer(
     config: OptimizerConfig,
     model_chunks: List[MegatronModule],
-    config_overrides: Optional[Dict[ParamKey, OptimizerConfig]] = None,
+    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]] = None,
     use_gloo_process_groups: bool = True,
     pg_collection: Optional[ProcessGroupCollection] = None,
     dump_param_to_param_group_map: Optional[str] = None,
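
The new check_config_overrides_consistency above replaces the inline assertion that the next hunk removes from get_megatron_optimizer. A hedged example of its expected behavior, restricted to the three protected fields listed in its body; the OptimizerConfig constructor arguments and the name-only ParamKey are assumptions carried over from earlier in this diff:

    from megatron.core.optimizer import check_config_overrides_consistency
    from megatron.core.optimizer.optimizer_config import OptimizerConfig, ParamKey

    config = OptimizerConfig(optimizer="adam", lr=3e-4)

    # Ordinary per-group fields may be overridden; the check passes and returns True.
    ok_overrides = {ParamKey(name="*.bias"): {"wd_mult": 0.0}}
    assert check_config_overrides_consistency(config, ok_overrides)

    # Overriding a protected field ('optimizer', 'optimizer_cpu_offload',
    # 'overlap_param_gather_with_optimizer_step') with a value that differs
    # from the base config raises a ValueError instead of silently diverging.
    bad_overrides = {ParamKey(name="*.bias"): {"optimizer": "sgd"}}
    try:
        check_config_overrides_consistency(config, bad_overrides)
    except ValueError as err:
        print(err)  # "Field optimizer should not be overriden in a config override."
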
@@ -468,19 +515,7 @@ def get_megatron_optimizer(

     log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

-    # TODO: Remove `optimizer` from this eventually (e.g., if we use Muon for some layers and
-    # Adam for other layers). This would need some more refactoring to work though (param_groups
-    # filtered by optimizer passed into _get_megatron_optimizer_based_on_param_groups).
-    fields_to_check_for_consistency = [
-        'overlap_param_gather_with_optimizer_step',
-        'optimizer',
-        'optimizer_cpu_offload',
-    ]
-    for field_name in fields_to_check_for_consistency:
-        field = getattr(config, field_name, None)
-        if config_overrides is not None:
-            all_configs = list(config_overrides.values())
-            assert all([getattr(x, field_name, None) == field for x in all_configs])
+    check_config_overrides_consistency(config, config_overrides)

     # Separate out first model chunk if overlapping param AG with optimizer step.
     if config.overlap_param_gather_with_optimizer_step:
