@@ -206,40 +206,31 @@ def get_optimizer_parameters(
206206 :returns: PARAMETERS. new parameter list.
207207 """
208208
209- def find_fully_qualified_names (
210- model : nn .Module ,
211- wd_ban_list : List [str ] = ("bias" , "LayerNorm.weight" , "LayerNorm.bias" ),
212- ):
213- names_without_wd = []
214-
215- for module_name , module in model .named_modules ():
216- for param_name , param in module .named_parameters (recurse = False ):
217- # Full parameter name includes module and parameter names
218- full_param_name = (
219- f"{ module_name } .{ param_name } " if module_name else param_name
220- )
221- # Check if any ban list substring is in the parameter name or module name
222- if (
223- any (banned in param_name for banned in wd_ban_list )
224- or any (banned in module_name for banned in wd_ban_list )
225- or any (banned in module ._get_name () for banned in wd_ban_list )
226- ):
227- names_without_wd .append (full_param_name )
228-
229- return names_without_wd
230-
231- full_names = find_fully_qualified_names (model_or_parameter , wd_ban_list )
209+ fully_qualified_names = []
210+ for module_name , module in model_or_parameter .named_modules ():
211+ for param_name , param in module .named_parameters (recurse = False ):
212+ # Full parameter name includes module and parameter names
213+ full_param_name = (
214+ f"{ module_name } .{ param_name } " if module_name else param_name
215+ )
216+ # Check if any ban list substring is in the parameter name or module name
217+ if (
218+ any (banned in param_name for banned in wd_ban_list )
219+ or any (banned in module_name for banned in wd_ban_list )
220+ or any (banned in module ._get_name () for banned in wd_ban_list )
221+ ):
222+ fully_qualified_names .append (full_param_name )
232223
233224 if isinstance (model_or_parameter , nn .Module ):
234225 model_or_parameter = list (model_or_parameter .named_parameters ())
235226
236227 return [
237228 {
238- 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in full_names )],
229+ 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in fully_qualified_names )],
239230 'weight_decay' : weight_decay ,
240231 },
241232 {
242- 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in full_names )],
233+ 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in fully_qualified_names )],
243234 'weight_decay' : 0.0 ,
244235 },
245236 ]
0 commit comments