@@ -198,23 +198,44 @@ def get_optimizer_parameters(
198198 weight_decay : float ,
199199 wd_ban_list : List [str ] = ('bias' , 'LayerNorm.bias' , 'LayerNorm.weight' ),
200200) -> PARAMETERS :
201- r"""Get optimizer parameters while filtering specified modules.
202-
201+ r"""
202+ Get optimizer parameters while filtering specified modules.
203203 :param model_or_parameter: Union[nn.Module, List]. model or parameters.
204204 :param weight_decay: float. weight_decay.
205205 :param wd_ban_list: List[str]. ban list not to set weight decay.
206206 :returns: PARAMETERS. new parameter list.
207207 """
208+
209+
210+ fully_qualified_names = []
211+ for module_name , module in model_or_parameter .named_modules ():
212+ for param_name , _param in module .named_parameters (recurse = False ):
213+ # Full parameter name includes module and parameter names
214+ full_param_name = f'{ module_name } .{ param_name } ' if module_name else param_name
215+ # Check if any ban list substring is in the parameter name or module name
216+ if (
217+ any (banned in param_name for banned in wd_ban_list )
218+ or any (banned in module_name for banned in wd_ban_list )
219+ or any (banned in module ._get_name () for banned in wd_ban_list )
220+ ):
221+ fully_qualified_names .append (full_param_name )
222+
208223 if isinstance (model_or_parameter , nn .Module ):
209224 model_or_parameter = list (model_or_parameter .named_parameters ())
210225
211226 return [
212227 {
213- 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in wd_ban_list )],
228+ 'params' : [
229+ p
230+ for n , p in model_or_parameter
231+ if p .requires_grad and not any (nd in n for nd in fully_qualified_names )
232+ ],
214233 'weight_decay' : weight_decay ,
215234 },
216235 {
217- 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in wd_ban_list )],
236+ 'params' : [
237+ p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in fully_qualified_names )
238+ ],
218239 'weight_decay' : 0.0 ,
219240 },
220241 ]
0 commit comments