
Commit a34da99

support apply wd to qk layernorm
1 parent 876a046 commit a34da99

File tree: 4 files changed, +121 -11 lines changed

megatron/core/optimizer/__init__.py

Lines changed: 36 additions & 11 deletions
@@ -88,9 +88,33 @@ def _matches(param: torch.nn.Parameter, param_name: str, param_key: ParamKey) ->
         return False
 
 
+def _get_no_wd_cond_fn(no_weight_decay_cond_type):
+    """Get the no weight decay condition function."""
+
+    if no_weight_decay_cond_type == 'apply_wd_to_qk_layernorm':
+
+        def no_wd_cond_fn(name, param):
+            if "q_layernorm" in name or "k_layernorm" in name:
+                # Apply weight decay to qk layernorm as a special case.
+                no_wd = False
+            else:
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+            return no_wd
+
+    elif no_weight_decay_cond_type is None:
+
+        def no_wd_cond_fn(name, param):
+            return name.endswith(".bias") or len(param.shape) == 1
+
+    else:
+        raise ValueError(f"Unknown no_weight_decay_cond_type: {no_weight_decay_cond_type}")
+
+    return no_wd_cond_fn
+
+
 def _get_param_groups(
     model_chunks: List[MegatronModule],
-    config: OptimizerConfig,
+    optimizer_config: OptimizerConfig,
     config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
 ) -> List[Dict]:
     """Create parameter groups for optimizer.
@@ -100,7 +124,7 @@ def _get_param_groups(
     Args:
         model_chunks (List[MegatronModule]): model chunks to create parameter
             groups for.
-        config (OptimizerConfig): optimizer configuration object.
+        optimizer_config (OptimizerConfig): optimizer configuration object.
         config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
             specified on a per-layer basis.
     Returns:
@@ -119,7 +143,7 @@ def _get_param_groups(
         uses_default_config = False
         # Get optimizer config for this parameter.
         if config_overrides is None:
-            config_for_param = config
+            config_for_param = optimizer_config
             uses_default_config = True
         else:
             config_for_param = None
@@ -129,15 +153,16 @@ def _get_param_groups(
                     break
             # Fall back to default config.
             if config_for_param is None:
-                config_for_param = config
+                config_for_param = optimizer_config
                 uses_default_config = True
 
         is_expert_parallel = not getattr(param, 'allreduce', True)
 
         # TODO: Make sure there is a way to support old no_weight_decay_func functionality
         # and default_skip_embedding_weight_decay:
         # or (default_skip_embedding_weight_decay and "embedding" in name)
-        no_wd = name.endswith(".bias") or len(param.shape) == 1
+        no_wd_cond_fn = _get_no_wd_cond_fn(optimizer_config.no_weight_decay_cond)
+        no_wd = no_wd_cond_fn(name, param)
         if not no_wd:
             wd_mult = 1.0
         else:
@@ -173,12 +198,12 @@ def _get_param_groups(
     for key in params_key:
         wd_mult, is_expert_parallel, _ = key
         params = params_map[key] if key in params_map else []
-        config, uses_default_config = None, True
+        param_config, uses_default_config = None, True
         if key not in configs_map:
             assert params == []
         else:
-            config, uses_default_config = configs_map[key]
-            assert config is not None
+            param_config, uses_default_config = configs_map[key]
+            assert param_config is not None
 
         # TODO: Remove "backwards compatible" fields below eventually.
         param_group = {
@@ -191,9 +216,9 @@ def _get_param_groups(
         }
 
         # Stick relevant fields into param_group from config object.
-        if config is not None:
-            param_group['max_lr'] = config.lr
-            param_group['min_lr'] = config.min_lr
+        if param_config is not None:
+            param_group['max_lr'] = param_config.lr
+            param_group['min_lr'] = param_config.min_lr
         # TODO: Add other relevant arguments (e.g., weight decay, optimizer)
         # here as well.
         param_groups.append(param_group)
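
To make the new selection rule concrete, here is a minimal, self-contained sketch of the condition logic added above, applied to a few illustrative Megatron-style parameter names. The names, shapes, and the standalone re-implementation are for illustration only; the actual function is _get_no_wd_cond_fn in this file.

    # Illustrative sketch only: mirrors the 'apply_wd_to_qk_layernorm' branch above.
    import torch

    def no_wd_cond_fn(name, param):
        if "q_layernorm" in name or "k_layernorm" in name:
            return False  # qk layernorm params keep weight decay under this option.
        return name.endswith(".bias") or len(param.shape) == 1

    # Hypothetical parameter names/shapes, loosely modeled on Megatron naming.
    examples = {
        "decoder.layers.0.self_attention.q_layernorm.weight": torch.ones(64),
        "decoder.layers.0.self_attention.k_layernorm.weight": torch.ones(64),
        "decoder.final_layernorm.weight": torch.ones(64),
        "decoder.layers.0.mlp.linear_fc1.weight": torch.ones(128, 64),
        "decoder.layers.0.mlp.linear_fc1.bias": torch.ones(128),
    }
    for name, param in examples.items():
        print(f"{name}: no_wd={no_wd_cond_fn(name, param)}")
    # The two qk layernorm weights and the 2D linear weight keep weight decay
    # (no_wd=False); the final layernorm weight and the bias skip it (no_wd=True).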

megatron/core/optimizer/optimizer_config.py

Lines changed: 9 additions & 0 deletions
@@ -41,6 +41,15 @@ class OptimizerConfig:
     weight_decay: float = 0.01
     """Weight decay coefficient for L2 regularization."""
 
+    no_weight_decay_cond: Optional[str] = None
+    """Condition for selecting which parameters should not get weight decay.
+    Supported conditions:
+    - None (default): skip weight decay for biases and other 1D parameters
+      (e.g., layernorm weights).
+    - "apply_wd_to_qk_layernorm": additionally apply weight decay to
+      qk layernorm parameters as a special case.
+    """
+
     ##############
     # Precision
     ##############
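
A minimal sketch of setting the new field when building the config programmatically. The import path and the other keyword values mirror the unit test added below; they are illustrative, not requirements of the field itself.

    from megatron.core.optimizer import OptimizerConfig

    # Enable weight decay on qk layernorm via the new config field.
    # Other kwargs are illustrative values mirroring the unit test below.
    optimizer_config = OptimizerConfig(
        optimizer='adam',
        lr=0.01,
        weight_decay=0.01,
        bf16=True,
        use_distributed_optimizer=False,
        no_weight_decay_cond='apply_wd_to_qk_layernorm',
    )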

megatron/training/arguments.py

Lines changed: 5 additions & 0 deletions
@@ -2000,6 +2000,11 @@ def _add_regularization_args(parser):
     group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
+    group.add_argument('--no-weight-decay-cond-type', type=str, choices=['apply_wd_to_qk_layernorm'],
+                       help='Type of no weight decay condition. Choices: '
+                            'None (default): skip weight decay for biases and 1D parameters. '
+                            '"apply_wd_to_qk_layernorm": additionally apply weight decay to '
+                            'qk layernorm as a special case.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,
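
For reference, a standalone sketch of the argparse pattern the new flag follows. This is not the actual Megatron parser, and how the parsed value is wired into OptimizerConfig.no_weight_decay_cond is not shown in this diff.

    import argparse

    # Standalone sketch mirroring the new --no-weight-decay-cond-type flag.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='regularization')
    group.add_argument('--no-weight-decay-cond-type', type=str,
                       choices=['apply_wd_to_qk_layernorm'],
                       help='Type of no weight decay condition.')

    args = parser.parse_args(['--no-weight-decay-cond-type', 'apply_wd_to_qk_layernorm'])
    print(args.no_weight_decay_cond_type)  # -> 'apply_wd_to_qk_layernorm'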

tests/unit_tests/test_optimizer.py

Lines changed: 71 additions & 0 deletions
@@ -598,3 +598,74 @@ def test_get_megatron_optimizer_custom_process_groups_validation():
             use_gloo_process_groups=True,  # Should be False when using custom groups
             pg_collection=pg_collection_complete,
         )
+
+
+class QKLayerNormModel(nn.Module):
+    """A model with q_layernorm, k_layernorm, a regular layernorm, and bias parameters,
+    used to test the 'apply_wd_to_qk_layernorm' no_weight_decay_cond option.
+    """
+
+    def __init__(self, hidden_size=64):
+        super().__init__()
+        # q_layernorm and k_layernorm should have wd_mult=1.0 when apply_wd_to_qk_layernorm is set.
+        self.q_layernorm = nn.LayerNorm(hidden_size, bias=True)
+        self.k_layernorm = nn.LayerNorm(hidden_size, bias=True)
+        # Regular layernorm should have wd_mult=0.0 (1D params).
+        self.regular_layernorm = nn.LayerNorm(hidden_size, bias=False)
+        # Linear layer: weight should have wd_mult=1.0, bias should have wd_mult=0.0.
+        self.linear = nn.Linear(hidden_size, hidden_size, bias=True)
+
+
+def test_no_weight_decay_cond_apply_wd_to_qk_layernorm():
+    """
+    Test that no_weight_decay_cond='apply_wd_to_qk_layernorm' correctly assigns
+    wd_mult=1.0 to q_layernorm and k_layernorm parameters while other 1D params
+    (biases, regular layernorm) get wd_mult=0.0.
+
+    This test uses get_megatron_optimizer to build an optimizer and then checks
+    the param_groups to verify the wd_mult assignment.
+    """
+    world = int(os.getenv('WORLD_SIZE', '1'))
+    rank = int(os.getenv('RANK', '0'))
+    _init_distributed(world, rank)
+    Utils.initialize_model_parallel()
+
+    # Create a model with q_layernorm, k_layernorm, and a regular layernorm.
+    model = QKLayerNormModel(hidden_size=64).bfloat16().cuda()
+    model.requires_grad_(True)
+
+    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
+    model = DistributedDataParallel(
+        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model
+    )
+
+    # Create an optimizer config with no_weight_decay_cond='apply_wd_to_qk_layernorm'.
+    optimizer_config = OptimizerConfig(
+        optimizer='adam',
+        lr=0.01,
+        bf16=True,
+        use_distributed_optimizer=False,
+        no_weight_decay_cond='apply_wd_to_qk_layernorm',
+    )
+
+    # Build the optimizer.
+    optim = get_megatron_optimizer(optimizer_config, [model])
+
+    # Count params by wd_mult.
+    wd_mult_1_count = 0  # Params with weight decay.
+    wd_mult_0_count = 0  # Params without weight decay.
+
+    for group in optim.param_groups:
+        wd_mult = group['wd_mult']
+        num_params = len(group['params'])
+        if wd_mult == 1.0:
+            wd_mult_1_count += num_params
+        else:
+            wd_mult_0_count += num_params
+
+    # Expected:
+    # wd_mult=1.0: q_layernorm.weight, q_layernorm.bias, k_layernorm.weight,
+    #              k_layernorm.bias, linear.weight = 5 params
+    # wd_mult=0.0: regular_layernorm.weight, linear.bias = 2 params
+    assert wd_mult_1_count == 5, f"Expected 5 params with wd_mult=1.0, but got {wd_mult_1_count}"
+    assert wd_mult_0_count == 2, f"Expected 2 params with wd_mult=0.0, but got {wd_mult_0_count}"
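
The expected counts in the asserts can be sanity-checked without any distributed setup by applying the same condition logic directly to the toy model's parameters. A sketch, assuming only a recent PyTorch (one whose nn.LayerNorm accepts the bias argument) is available:

    import torch.nn as nn

    # Same toy model as the test, checked directly against the condition logic.
    class ToyQKLayerNormModel(nn.Module):
        def __init__(self, hidden_size=64):
            super().__init__()
            self.q_layernorm = nn.LayerNorm(hidden_size, bias=True)
            self.k_layernorm = nn.LayerNorm(hidden_size, bias=True)
            self.regular_layernorm = nn.LayerNorm(hidden_size, bias=False)
            self.linear = nn.Linear(hidden_size, hidden_size, bias=True)

    def no_wd(name, param):
        if "q_layernorm" in name or "k_layernorm" in name:
            return False
        return name.endswith(".bias") or len(param.shape) == 1

    model = ToyQKLayerNormModel()
    with_wd = [n for n, p in model.named_parameters() if not no_wd(n, p)]
    without_wd = [n for n, p in model.named_parameters() if no_wd(n, p)]
    assert len(with_wd) == 5     # q/k layernorm weight+bias, linear.weight
    assert len(without_wd) == 2  # regular_layernorm.weight, linear.bias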
