@@ -12,7 +12,7 @@
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
 
-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
 from timm.optim import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
 
@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):
 
 @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
 def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
-
-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-    )
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
+
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+        )
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
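
For reference, a minimal sketch (not part of this diff) of how the optimizer introspection API exercised above could be used on its own. It assumes only what the test itself shows: list_optimizers(), get_optimizer_class(), and get_optimizer_info() returning an OptimInfo whose second_order flag marks optimizers that the basic factory cases skip; any other OptimInfo fields are not shown here.

# Hypothetical usage sketch based only on the API calls visible in the test above.
import torch
from timm.optim import list_optimizers, get_optimizer_class, get_optimizer_info

for name in list_optimizers(exclude_filters=('fused*', 'bnb*')):
    opt_cls = get_optimizer_class(name)    # concrete torch.optim.Optimizer subclass
    opt_info = get_optimizer_info(name)    # OptimInfo metadata record
    assert issubclass(opt_cls, torch.optim.Optimizer)
    if opt_info.second_order:
        # second-order optimizers are excluded from the basic factory cases above
        print(f'{name}: second order, needs dedicated handling')
    else:
        print(f'{name}: first order, covered by _test_basic_cases')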