@@ -205,16 +205,41 @@ def get_optimizer_parameters(
205205 :param wd_ban_list: List[str]. ban list not to set weight decay.
206206 :returns: PARAMETERS. new parameter list.
207207 """
208+
def find_fully_qualified_names(
    model: nn.Module,
    wd_ban_list: List[str] = ("bias", "LayerNorm.weight", "LayerNorm.bias"),
):
    r"""Collect the fully qualified names of the parameters excluded from weight decay.

    A parameter is excluded when any entry of ``wd_ban_list`` is a substring of either

    * its fully qualified name (``<module path>.<parameter name>``), or
    * its class-qualified name (``<module class name>.<parameter name>``).

    The first form covers matches on the module path, the parameter name and dotted entries
    spanning both (e.g. ``LayerNorm.weight`` for a module registered under the name
    ``LayerNorm``). The second form covers matches on the module class, including dotted
    entries for modules registered under an unrelated attribute name.

    :param model: nn.Module. model to inspect. a re-iterable collection of ``(name, parameter)``
        pairs is accepted as well, in which case only the given names are matched.
    :param wd_ban_list: List[str]. ban list not to set weight decay.
    :returns: List[str]. fully qualified parameter names without weight decay, in traversal order.
    """

    def is_banned(*candidates: str) -> bool:
        # Substring match of every ban entry against every candidate spelling of the name.
        return any(banned in candidate for banned in wd_ban_list for candidate in candidates)

    if not isinstance(model, nn.Module):
        # No module hierarchy is available for plain ``(name, parameter)`` pairs, so only the
        # provided names can be checked. This keeps non-module input from raising.
        return [name for name, _ in model if is_banned(name)]

    names_without_wd = []

    for module_name, module in model.named_modules():
        class_name = module._get_name()

        # recurse=False: each parameter is visited once, under the module that directly owns it.
        for param_name, _ in module.named_parameters(recurse=False):
            # The root module has an empty name, so its parameters keep their bare name.
            full_param_name = f"{module_name}.{param_name}" if module_name else param_name

            if is_banned(full_param_name, f"{class_name}.{param_name}"):
                names_without_wd.append(full_param_name)

    return names_without_wd
230+
231+ full_names = find_fully_qualified_names (model_or_parameter , wd_ban_list )
232+
208233 if isinstance (model_or_parameter , nn .Module ):
209234 model_or_parameter = list (model_or_parameter .named_parameters ())
210-
235+
211236 return [
212237 {
213- 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in wd_ban_list )],
238+ 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in full_names )],
214239 'weight_decay' : weight_decay ,
215240 },
216241 {
217- 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in wd_ban_list )],
242+ 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in full_names )],
218243 'weight_decay' : 0.0 ,
219244 },
220245 ]
0 commit comments