|
1 | 1 | import math |
2 | 2 | import warnings |
3 | 3 | from importlib.util import find_spec |
4 | | -from typing import Callable, Dict, List, Optional, Tuple, Union |
| 4 | +from typing import Callable, Dict, List, Optional, Set, Tuple, Union |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import torch |
@@ -198,43 +198,45 @@ def get_optimizer_parameters( |
198 | 198 | weight_decay: float, |
199 | 199 | wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'), |
200 | 200 | ) -> PARAMETERS: |
201 | | - r""" |
202 | | - Get optimizer parameters while filtering specified modules. |
| 201 | + r"""Get optimizer parameters while filtering specified modules. |
| 202 | +
|
| 203 | + Note that you can also ban at the module-name level (e.g. LayerNorm) if you pass an nn.Module instance. You |
| 204 | + only need to input `LayerNorm` to exclude weight decay from the layer norm layer(s). |
| 205 | +
|
203 | 206 | :param model_or_parameter: Union[nn.Module, List]. model or parameters. |
204 | 207 | :param weight_decay: float. weight_decay. |
205 | 208 | :param wd_ban_list: List[str]. list of name patterns to exclude from weight decay. |
206 | 209 | :returns: PARAMETERS. new parameter list. |
207 | 210 | """ |
208 | | - |
209 | | - |
210 | | - fully_qualified_names = [] |
211 | | - for module_name, module in model_or_parameter.named_modules(): |
212 | | - for param_name, _param in module.named_parameters(recurse=False): |
213 | | - # Full parameter name includes module and parameter names |
214 | | - full_param_name = f'{module_name}.{param_name}' if module_name else param_name |
215 | | - # Check if any ban list substring is in the parameter name or module name |
216 | | - if ( |
217 | | - any(banned in param_name for banned in wd_ban_list) |
218 | | - or any(banned in module_name for banned in wd_ban_list) |
219 | | - or any(banned in module._get_name() for banned in wd_ban_list) |
220 | | - ): |
221 | | - fully_qualified_names.append(full_param_name) |
| 211 | + banned_parameter_patterns: Set[str] = set() |
222 | 212 |
|
223 | 213 | if isinstance(model_or_parameter, nn.Module): |
| 214 | + for module_name, module in model_or_parameter.named_modules(): |
| 215 | + for param_name, _ in module.named_parameters(recurse=False): |
| 216 | + full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name |
| 217 | + if any( |
| 218 | + banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name()) |
| 219 | + ): |
| 220 | + banned_parameter_patterns.add(full_param_name) |
| 221 | + |
224 | 222 | model_or_parameter = list(model_or_parameter.named_parameters()) |
| 223 | + else: |
| 224 | + banned_parameter_patterns.update(wd_ban_list) |
225 | 225 |
|
226 | 226 | return [ |
227 | 227 | { |
228 | 228 | 'params': [ |
229 | 229 | p |
230 | 230 | for n, p in model_or_parameter |
231 | | - if p.requires_grad and not any(nd in n for nd in fully_qualified_names) |
| 231 | + if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns) |
232 | 232 | ], |
233 | 233 | 'weight_decay': weight_decay, |
234 | 234 | }, |
235 | 235 | { |
236 | 236 | 'params': [ |
237 | | - p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in fully_qualified_names) |
| 237 | + p |
| 238 | + for n, p in model_or_parameter |
| 239 | + if p.requires_grad and any(nd in n for nd in banned_parameter_patterns) |
238 | 240 | ], |
239 | 241 | 'weight_decay': 0.0, |
240 | 242 | }, |
|
0 commit comments