Commit 0e6da65

More fixes for new factory & tests, add back adahessian
1 parent 5dae918 commit 0e6da65

6 files changed, +69 -31 lines changed

hfdocs/source/reference/optimizers.mdx

Lines changed: 2 additions & 1 deletion
@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adan.Adan
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
-[[autodoc]] timm.optim.lion,Lion
+[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam

tests/test_optim.py

Lines changed: 27 additions & 23 deletions
@@ -12,7 +12,7 @@
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter

-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
from timm.optim import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):

@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
-
-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-    )
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
+
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+        )


#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
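
The reworked test now checks the registry metadata before running the generic optimizer cases. A minimal standalone sketch of the same checks, assuming a timm build that includes this commit ('adahessian' is used purely as an example name):

import torch
from timm.optim import get_optimizer_class, get_optimizer_info, OptimInfo

# resolve the registered class and its metadata for one optimizer name
opt_cls = get_optimizer_class('adahessian')
assert issubclass(opt_cls, torch.optim.Optimizer)

info = get_optimizer_info('adahessian')
assert isinstance(info, OptimInfo)
assert info.second_order  # adahessian is registered with second_order=True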

timm/optim/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -17,6 +17,6 @@
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP

-from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
-    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
+    create_optimizer_v2, create_optimizer, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
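
For reference, the re-exported names after this change are all importable from the package level; a quick illustrative sketch (not part of the commit):

from timm.optim import (
    create_optimizer_v2, create_optimizer, optimizer_kwargs,
    list_optimizers, get_optimizer_class, get_optimizer_info,
    OptimInfo, OptimizerRegistry,
    param_groups_layer_decay, param_groups_weight_decay, auto_group_layers,
)

# same filters as the factory test above
names = list_optimizers(exclude_filters=('fused*', 'bnb*'))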

timm/optim/_optim_factory.py

Lines changed: 25 additions & 1 deletion
@@ -13,10 +13,11 @@
import torch.nn as nn
import torch.optim as optim

-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
+from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
    has_momentum: bool = False
    has_betas: bool = False
    num_betas: int = 2
+    second_order: bool = False
    defaults: Optional[Dict[str, Any]] = None


@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
            has_betas=True,
            num_betas=3
        ),
+        OptimInfo(
+            name='adahessian',
+            opt_class=Adahessian,
+            description='An Adaptive Second Order Optimizer',
+            has_betas=True,
+            second_order=True,
+        ),
        OptimInfo(
            name='lion',
            opt_class=Lion,
@@ -770,6 +779,21 @@ def list_optimizers(
    return default_registry.list_optimizers(filter, exclude_filters, with_description)


+def get_optimizer_info(name: str) -> OptimInfo:
+    """Get the OptimInfo for an optimizer.
+
+    Args:
+        name: Name of the optimizer
+
+    Returns:
+        OptimInfo configuration
+
+    Raises:
+        ValueError: If optimizer is not found
+    """
+    return default_registry.get_optimizer_info(name)
+
+
def get_optimizer_class(
        name: str,
        bind_defaults: bool = False,
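
The new second_order flag gives callers a way to detect optimizers like Adahessian that need gradients created with create_graph=True so the Hessian-diagonal estimate can be computed in step(). A rough usage sketch, assuming this commit's API; the toy model, data, and loss are placeholders:

import torch
import torch.nn.functional as F
from timm.optim import create_optimizer_v2, get_optimizer_info

model = torch.nn.Linear(10, 2)
opt_name = 'adahessian'
optimizer = create_optimizer_v2(model, opt=opt_name, lr=1e-3)
second_order = get_optimizer_info(opt_name).second_order

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = F.mse_loss(model(x), y)
loss.backward(create_graph=second_order)  # keep the graph so step() can estimate the Hessian
optimizer.step()
optimizer.zero_grad()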

timm/optim/_param_groups.py

Lines changed: 5 additions & 3 deletions
@@ -1,6 +1,6 @@
import logging
from itertools import islice
-from typing import Collection, Optional, Tuple
+from typing import Collection, Optional

from torch import nn as nn

@@ -37,7 +37,7 @@ def _group(it, size):
    return iter(lambda: tuple(islice(it, size)), ())


-def _layer_map(model, layers_per_group=12, num_groups=None):
+def auto_group_layers(model, layers_per_group=12, num_groups=None):
    def _in_head(n, hp):
        if not hp:
            return True
@@ -63,6 +63,8 @@ def _in_head(n, hp):
    layer_map.update({n: num_trunk_groups for n in names_head})
    return layer_map

+_layer_map = auto_group_layers  # backward compat
+

def param_groups_layer_decay(
        model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
-        layer_map = _layer_map(model)
+        layer_map = auto_group_layers(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
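
auto_group_layers (previously _layer_map, which stays as an alias) builds a {parameter_name: group_index} mapping that param_groups_layer_decay falls back to when a model exposes no group_matcher(). A hedged sketch of calling the renamed helper; the resnet18 model is just an example:

import timm
from timm.optim import param_groups_layer_decay, auto_group_layers

model = timm.create_model('resnet18', pretrained=False)

layer_map = auto_group_layers(model, layers_per_group=4)  # {param_name: group_index}
num_groups = max(layer_map.values()) + 1

# per-group lr scales follow layer_decay ** (distance from the top group)
param_groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)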

timm/optim/optim_factory.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
+
+from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
+
+import warnings
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
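
The shim keeps the long-standing `import timm.optim.optim_factory as optim_factory` pattern working while nudging callers toward the package-level imports. A small sketch of the expected behaviour (illustrative, not part of the commit):

import timm.optim.optim_factory as optim_factory  # still works, emits a FutureWarning on first import
from timm.optim import create_optimizer_v2        # preferred path going forward

# both names resolve to the same function re-exported from timm.optim._optim_factory
assert optim_factory.create_optimizer_v2 is create_optimizer_v2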
