@@ -88,9 +88,33 @@ def _matches(param: torch.nn.Parameter, param_name: str, param_key: ParamKey) ->
     return False


+def _get_no_wd_cond_fn(no_weight_decay_cond_type):
+    """Get the no-weight-decay condition function."""
+
+    if no_weight_decay_cond_type == 'apply_wd_to_qk_layernorm':
+
+        def no_wd_cond_fn(name, param):
+            # Apply weight decay to qk layernorm params as a special case.
+            if "q_layernorm" in name or "k_layernorm" in name:
+                no_wd = False
+            else:
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+            return no_wd
+
+    elif no_weight_decay_cond_type is None:
+
+        def no_wd_cond_fn(name, param):
+            return name.endswith(".bias") or len(param.shape) == 1
+
+    else:
+        raise ValueError(f"Unknown no_weight_decay_cond_type: {no_weight_decay_cond_type}")
+
+    return no_wd_cond_fn
+
+
 def _get_param_groups(
     model_chunks: List[MegatronModule],
-    config: OptimizerConfig,
+    optimizer_config: OptimizerConfig,
     config_overrides: Optional[Dict[ParamKey, OptimizerConfig]],
 ) -> List[Dict]:
     """Create parameter groups for optimizer.
@@ -100,7 +124,7 @@ def _get_param_groups(
     Args:
         model_chunks (List[MegatronModule]): model chunks to create parameter
             groups for.
-        config (OptimizerConfig): optimizer configuration object.
+        optimizer_config (OptimizerConfig): optimizer configuration object.
         config_overrides (Optional[Dict[LayerKey, OptimizerConfig]): optimizer overrides,
             specified on a per-layer basis.
     Returns:
@@ -119,7 +143,7 @@ def _get_param_groups(
             uses_default_config = False
             # Get optimizer config for this parameter.
             if config_overrides is None:
-                config_for_param = config
+                config_for_param = optimizer_config
                 uses_default_config = True
             else:
                 config_for_param = None
@@ -129,15 +153,16 @@ def _get_param_groups(
                         break
                 # Fall back to default config.
                 if config_for_param is None:
-                    config_for_param = config
+                    config_for_param = optimizer_config
                     uses_default_config = True

             is_expert_parallel = not getattr(param, 'allreduce', True)

             # TODO: Make sure there is a way to support old no_weight_decay_func functionality
             # and default_skip_embedding_weight_decay:
             # or (default_skip_embedding_weight_decay and "embedding" in name)
-            no_wd = name.endswith(".bias") or len(param.shape) == 1
+            no_wd_cond_fn = _get_no_wd_cond_fn(optimizer_config.no_weight_decay_cond)
+            no_wd = no_wd_cond_fn(name, param)
             if not no_wd:
                 wd_mult = 1.0
             else:
@@ -173,12 +198,12 @@ def _get_param_groups(
     for key in params_key:
         wd_mult, is_expert_parallel, _ = key
         params = params_map[key] if key in params_map else []
-        config, uses_default_config = None, True
+        param_config, uses_default_config = None, True
         if key not in configs_map:
             assert params == []
         else:
-            config, uses_default_config = configs_map[key]
-            assert config is not None
+            param_config, uses_default_config = configs_map[key]
+            assert param_config is not None

         # TODO: Remove "backwards compatible" fields below eventually.
         param_group = {
@@ -191,9 +216,9 @@ def _get_param_groups(
         }

         # Stick relevant fields into param_group from config object.
-        if config is not None:
-            param_group['max_lr'] = config.lr
-            param_group['min_lr'] = config.min_lr
+        if param_config is not None:
+            param_group['max_lr'] = param_config.lr
+            param_group['min_lr'] = param_config.min_lr
             # TODO: Add other relevant arguments (e.g., weight decay, optimizer)
             # here as well.
         param_groups.append(param_group)
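
For reference, a minimal sketch of how the new no_weight_decay_cond option behaves. The import path and the SimpleNamespace stand-ins for torch.nn.Parameter are assumptions for illustration only; the condition function only inspects the parameter name and .shape here.

# Illustrative only: import path is an assumption, not part of this diff.
from types import SimpleNamespace

from megatron.core.optimizer import _get_no_wd_cond_fn  # hypothetical location

# Stand-ins for real torch.nn.Parameter objects; only .shape is consulted.
weight_2d = SimpleNamespace(shape=(1024, 1024))  # ordinary weight matrix
gamma_1d = SimpleNamespace(shape=(1024,))        # layernorm scale / 1-D tensor

default_fn = _get_no_wd_cond_fn(None)
qk_fn = _get_no_wd_cond_fn('apply_wd_to_qk_layernorm')

# Default behavior: biases and 1-D tensors skip weight decay.
assert default_fn("decoder.layers.0.mlp.linear_fc1.weight", weight_2d) is False
assert default_fn("decoder.layers.0.self_attention.q_layernorm.weight", gamma_1d) is True

# With 'apply_wd_to_qk_layernorm', qk layernorm params get weight decay again,
# while other biases and 1-D tensors are still excluded.
assert qk_fn("decoder.layers.0.self_attention.q_layernorm.weight", gamma_1d) is False
assert qk_fn("decoder.layers.0.mlp.linear_fc1.bias", gamma_1d) is True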