Commit c6d64ef

Merge pull request #79 from kozistr/feature/lr_scheduler
[Feature] Supports LR Schedulers
2 parents dc01144 + c0c7aaa commit c6d64ef

44 files changed: +499, -131 lines

README.rst

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Install

 ::

-    $ pip3 install pytorch-optimizer
+    $ pip3 install -U pytorch-optimizer

 Simple Usage
 ~~~~~~~~~~~~

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "1.3.2"
+version = "2.0.0"
 description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
@@ -9,7 +9,7 @@ readme = "README.rst"
 homepage = "https://github.com/kozistr/pytorch_optimizer"
 repository = "https://github.com/kozistr/pytorch_optimizer"
 documentation = "https://pytorch-optimizers.readthedocs.io/en/latest"
-keywords = ["pytorch", "deep-learning", "optimizer"]
+keywords = ["pytorch", "deep-learning", "optimizer", "lr scheduler"]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Development Status :: 5 - Production/Stable",

pytorch_optimizer/__init__.py

Lines changed: 53 additions & 33 deletions
@@ -1,42 +1,42 @@
 # pylint: disable=unused-import
-from typing import Dict, List, Type
+from typing import Dict, List

-from torch.optim import Optimizer
-
-from pytorch_optimizer.adabelief import AdaBelief
-from pytorch_optimizer.adabound import AdaBound
-from pytorch_optimizer.adamp import AdamP
-from pytorch_optimizer.adan import Adan
-from pytorch_optimizer.adapnm import AdaPNM
-from pytorch_optimizer.agc import agc
-from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
-from pytorch_optimizer.diffgrad import DiffGrad
-from pytorch_optimizer.diffrgrad import DiffRGrad
-from pytorch_optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
-from pytorch_optimizer.gc import centralize_gradient
-from pytorch_optimizer.lamb import Lamb
-from pytorch_optimizer.lars import LARS
-from pytorch_optimizer.lookahead import Lookahead
-from pytorch_optimizer.madgrad import MADGRAD
-from pytorch_optimizer.nero import Nero
-from pytorch_optimizer.pcgrad import PCGrad
-from pytorch_optimizer.pnm import PNM
-from pytorch_optimizer.radam import RAdam
-from pytorch_optimizer.ralamb import RaLamb
-from pytorch_optimizer.ranger import Ranger
-from pytorch_optimizer.ranger21 import Ranger21
-from pytorch_optimizer.sam import SAM
-from pytorch_optimizer.sgdp import SGDP
-from pytorch_optimizer.shampoo import Shampoo
-from pytorch_optimizer.utils import (
+from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
+from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_schedule
+from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts
+from pytorch_optimizer.optimizer.adabelief import AdaBelief
+from pytorch_optimizer.optimizer.adabound import AdaBound
+from pytorch_optimizer.optimizer.adamp import AdamP
+from pytorch_optimizer.optimizer.adan import Adan
+from pytorch_optimizer.optimizer.adapnm import AdaPNM
+from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.diffgrad import DiffGrad
+from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
+from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
+from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.lamb import Lamb
+from pytorch_optimizer.optimizer.lars import LARS
+from pytorch_optimizer.optimizer.lookahead import Lookahead
+from pytorch_optimizer.optimizer.madgrad import MADGRAD
+from pytorch_optimizer.optimizer.nero import Nero
+from pytorch_optimizer.optimizer.pcgrad import PCGrad
+from pytorch_optimizer.optimizer.pnm import PNM
+from pytorch_optimizer.optimizer.radam import RAdam
+from pytorch_optimizer.optimizer.ralamb import RaLamb
+from pytorch_optimizer.optimizer.ranger import Ranger
+from pytorch_optimizer.optimizer.ranger21 import Ranger21
+from pytorch_optimizer.optimizer.sam import SAM
+from pytorch_optimizer.optimizer.sgdp import SGDP
+from pytorch_optimizer.optimizer.shampoo import Shampoo
+from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,
     get_optimizer_parameters,
     matrix_power,
     normalize_gradient,
     unit_norm,
 )

-OPTIMIZER_LIST: List[Type[Optimizer]] = [
+OPTIMIZER_LIST: List[OPTIMIZER] = [
     AdaBelief,
     AdaBound,
     AdamP,
@@ -56,10 +56,17 @@
     SGDP,
     Shampoo,
 ]
-OPTIMIZERS: Dict[str, Type[Optimizer]] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
+OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
+
+LR_SCHEDULER_LIST: List[SCHEDULER] = [
+    CosineAnnealingWarmupRestarts,
+]
+LR_SCHEDULERS: Dict[str, SCHEDULER] = {
+    str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
+}


-def load_optimizer(optimizer: str) -> Type[Optimizer]:
+def load_optimizer(optimizer: str) -> OPTIMIZER:
     optimizer: str = optimizer.lower()

     if optimizer not in OPTIMIZERS:
@@ -68,5 +75,18 @@ def load_optimizer(optimizer: str) -> Type[Optimizer]:
     return OPTIMIZERS[optimizer]


-def get_supported_optimizers() -> List[Type[Optimizer]]:
+def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
+    lr_scheduler: str = lr_scheduler.lower()
+
+    if lr_scheduler not in LR_SCHEDULERS:
+        raise NotImplementedError(f'[-] not implemented lr_scheduler : {lr_scheduler}')
+
+    return LR_SCHEDULERS[lr_scheduler]
+
+
+def get_supported_optimizers() -> List[OPTIMIZER]:
     return OPTIMIZER_LIST
+
+
+def get_supported_lr_schedulers() -> List[SCHEDULER]:
+    return LR_SCHEDULER_LIST
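
A minimal usage sketch of the registry helpers added here (lookup keys are the lowercased class names; the CosineAnnealingWarmupRestarts constructor arguments are defined outside this diff, so only the class lookup is shown):

    from torch import nn

    from pytorch_optimizer import get_supported_lr_schedulers, load_lr_scheduler, load_optimizer

    model = nn.Linear(10, 2)

    # optimizers are looked up by lowercased class name, e.g. 'adamp', 'ranger21'
    optimizer = load_optimizer('adamp')(model.parameters(), lr=1e-3)

    # the same lookup pattern now works for lr schedulers
    print([scheduler.__name__ for scheduler in get_supported_lr_schedulers()])
    scheduler_class = load_lr_scheduler('cosineannealingwarmuprestarts')
    # scheduler = scheduler_class(optimizer, ...)  # constructor arguments are not part of this diff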

pytorch_optimizer/base/__init__.py

Whitespace-only changes.

pytorch_optimizer/base_optimizer.py renamed to pytorch_optimizer/base/base_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 import torch

-from pytorch_optimizer.types import BETAS
+from pytorch_optimizer.base.types import BETAS


 class BaseOptimizer(ABC):

pytorch_optimizer/types.py renamed to pytorch_optimizer/base/types.py

Lines changed: 5 additions & 1 deletion

@@ -1,10 +1,14 @@
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

 import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 CLOSURE = Optional[Callable[[], float]]
 LOSS = Optional[float]
 BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
 DEFAULTS = Dict[str, Any]
 PARAMETERS = Optional[Union[Iterable[Dict[str, Any]], Iterable[torch.Tensor]]]
 STATE = Dict[str, Any]
+OPTIMIZER = Type[Optimizer]
+SCHEDULER = Type[_LRScheduler]
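
The new aliases are class-level types (Type[...]), not instances; a minimal sketch of how they can annotate a factory-style helper (the build_optimizer function below is illustrative, not part of this commit):

    from torch import nn
    from torch.optim import Optimizer

    from pytorch_optimizer.base.types import OPTIMIZER


    def build_optimizer(optimizer_class: OPTIMIZER, model: nn.Module, lr: float = 1e-3) -> Optimizer:
        # OPTIMIZER is Type[Optimizer]: the argument is the optimizer class itself,
        # and instantiation happens here (most optimizers in this package accept an `lr` kwarg)
        return optimizer_class(model.parameters(), lr=lr)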

pytorch_optimizer/experimental/__init__.py

Whitespace-only changes.

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+from torch import nn
+
+from pytorch_optimizer.base.types import PARAMETERS
+
+
+def deberta_v3_large_lr_scheduler(
+    model: nn.Module,
+    head_param_start: int = 390,
+    base_lr: float = 2e-5,
+    head_lr: float = 1e-4,
+    wd: float = 1e-2,
+) -> PARAMETERS:
+    """DeBERTa-v3 large layer-wise lr scheduler
+    Reference : https://github.com/gilfernandes/commonlit
+
+    :param model: nn.Module. model. based on Huggingface Transformers.
+    :param head_param_start: int. where the backbone ends (head starts)
+    :param base_lr: float. base lr
+    :param head_lr: float. head_lr
+    :param wd: float. weight decay
+    """
+    named_parameters = list(model.named_parameters())
+
+    backbone_parameters = named_parameters[:head_param_start]
+    regressor_parameters = named_parameters[head_param_start:]
+
+    regressor_group = [params for (_, params) in regressor_parameters]
+
+    parameters = [{'params': regressor_group, 'lr': head_lr}]
+
+    layer_low_threshold: int = 195  # start of the 12 layers
+    layer_middle_threshold: int = 323  # end of the 24 layers
+
+    for layer_num, (name, params) in enumerate(backbone_parameters):
+        weight_decay: float = 0.0 if ('bias' in name) or ('LayerNorm.weight' in name) else wd
+
+        lr = base_lr / 2.5  # 2e-5
+        if layer_num >= layer_middle_threshold:
+            lr = base_lr / 0.5  # 1e-4
+        elif layer_num >= layer_low_threshold:
+            lr = base_lr
+
+        parameters.append({'params': params, 'weight_decay': weight_decay, 'lr': lr})
+
+    return parameters
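
A possible way to consume these parameter groups with an optimizer from this package, assuming a Hugging Face DeBERTa-v3-large checkpoint (the checkpoint name, the AdamP choice, and the import path are assumptions, since the new file's path is not visible in this excerpt):

    from transformers import AutoModel  # assumes Hugging Face Transformers is installed

    from pytorch_optimizer import load_optimizer
    # module path is assumed; the new file's name is not shown in this excerpt
    from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler

    model = AutoModel.from_pretrained('microsoft/deberta-v3-large')

    # layer-wise groups: lower lr for early backbone layers, higher lr for the head,
    # and no weight decay on biases or LayerNorm weights
    parameters = deberta_v3_large_lr_scheduler(model, head_param_start=390, base_lr=2e-5, head_lr=1e-4)

    # per-group lr values in `parameters` override the optimizer-level default
    optimizer = load_optimizer('adamp')(parameters, lr=2e-5)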

pytorch_optimizer/lr_scheduler/__init__.py

Whitespace-only changes.
