@@ -198,21 +198,20 @@ def get_optimizer_parameters(
198198 weight_decay : float ,
199199 wd_ban_list : List [str ] = ('bias' , 'LayerNorm.bias' , 'LayerNorm.weight' ),
200200) -> PARAMETERS :
201- r"""Get optimizer parameters while filtering specified modules.
202-
201+ r"""
202+ Get optimizer parameters while filtering specified modules.
203203 :param model_or_parameter: Union[nn.Module, List]. model or parameters.
204204 :param weight_decay: float. weight_decay.
205205 :param wd_ban_list: List[str]. ban list not to set weight decay.
206206 :returns: PARAMETERS. new parameter list.
207207 """
208208
209+
209210 fully_qualified_names = []
210211 for module_name , module in model_or_parameter .named_modules ():
211- for param_name , param in module .named_parameters (recurse = False ):
212+ for param_name , _param in module .named_parameters (recurse = False ):
212213 # Full parameter name includes module and parameter names
213- full_param_name = (
214- f"{ module_name } .{ param_name } " if module_name else param_name
215- )
214+ full_param_name = f'{ module_name } .{ param_name } ' if module_name else param_name
216215 # Check if any ban list substring is in the parameter name or module name
217216 if (
218217 any (banned in param_name for banned in wd_ban_list )
@@ -223,14 +222,20 @@ def get_optimizer_parameters(
223222
224223 if isinstance (model_or_parameter , nn .Module ):
225224 model_or_parameter = list (model_or_parameter .named_parameters ())
226-
225+
227226 return [
228227 {
229- 'params' : [p for n , p in model_or_parameter if p .requires_grad and not any (nd in n for nd in fully_qualified_names )],
228+ 'params' : [
229+ p
230+ for n , p in model_or_parameter
231+ if p .requires_grad and not any (nd in n for nd in fully_qualified_names )
232+ ],
230233 'weight_decay' : weight_decay ,
231234 },
232235 {
233- 'params' : [p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in fully_qualified_names )],
236+ 'params' : [
237+ p for n , p in model_or_parameter if p .requires_grad and any (nd in n for nd in fully_qualified_names )
238+ ],
234239 'weight_decay' : 0.0 ,
235240 },
236241 ]
0 commit comments