
Commit 2e41dfa

refactor: types
1 parent a6f2e5d

8 files changed, 75 additions and 49 deletions

pytorch_optimizer/adamp.py

Lines changed: 11 additions & 7 deletions
@@ -1,24 +1,26 @@
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, List, Tuple

 import torch
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer

+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
+

 class AdamP(Optimizer):
     def __init__(
         self,
         params,
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         delta: float = 0.1,
         wd_ratio: float = 0.1,
         nesterov: bool = False,
     ):
-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
@@ -39,7 +41,10 @@ def layer_view(x: torch.Tensor) -> torch.Tensor:

     @staticmethod
     def cosine_similarity(
-        x: torch.Tensor, y: torch.Tensor, eps: float, view_func: Callable
+        x: torch.Tensor,
+        y: torch.Tensor,
+        eps: float,
+        view_func: Callable[[torch.Tensor], torch.Tensor],
     ):
         x = view_func(x)
         y = view_func(y)
@@ -74,8 +79,8 @@ def projection(

         return perturb, wd

-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None:
             loss = closure()

@@ -114,7 +119,6 @@ def step(self, closure: Optional[Callable] = None) -> float:
             else:
                 perturb = exp_avg / denom

-            # Projection
             wd_ratio: float = 1
             if len(p.shape) > 1:
                 perturb, wd_ratio = self.projection(
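
The tightened annotation Callable[[torch.Tensor], torch.Tensor] describes view helpers such as the layer_view referenced in the hunk header above. A minimal sketch of a conforming view function and a cosine-similarity call, where the bodies are illustrative assumptions and not part of this commit:

import torch
import torch.nn.functional as F


def layer_view(x: torch.Tensor) -> torch.Tensor:
    # Flatten the parameter tensor to a single row; this satisfies the
    # Callable[[torch.Tensor], torch.Tensor] contract in the new signature.
    return x.view(1, -1)


# Hypothetical call mirroring the inputs of AdamP.cosine_similarity.
x, y = torch.randn(4, 3, 3), torch.randn(4, 3, 3)
sim = F.cosine_similarity(layer_view(x), layer_view(y), dim=1, eps=1e-8)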

pytorch_optimizer/lookahead.py

Lines changed: 21 additions & 13 deletions
@@ -1,19 +1,27 @@
 from collections import defaultdict
-from typing import Callable, Dict, List, Optional
+from typing import Dict

 import torch
 from torch.optim import Optimizer

+from pytorch_optimizer.types import (
+    CLOSURE,
+    LOSS,
+    PARAM_GROUP,
+    PARAM_GROUPS,
+    STATE,
+)
+

 class Lookahead(Optimizer):
     def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
         self.optimizer = optimizer
         self.k = k
         self.alpha = alpha

-        self.param_groups: List[Dict] = self.optimizer.param_groups
-        self.fast_state: Dict = self.optimizer.state
-        self.state = defaultdict(dict)
+        self.param_groups: PARAM_GROUPS = self.optimizer.param_groups
+        self.fast_state: STATE = self.optimizer.state
+        self.state: STATE = defaultdict(dict)

         for group in self.param_groups:
             group['counter'] = 0
@@ -32,8 +40,8 @@ def update_lookahead(self):
         for group in self.param_groups:
             self.update(group)

-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: float = self.optimizer.step(closure)
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = self.optimizer.step(closure)
         for group in self.param_groups:
             if group['counter'] == 0:
                 self.update(group)
@@ -42,12 +50,12 @@ def step(self, closure: Optional[Callable] = None) -> float:
             group['counter'] = 0
         return loss

-    def state_dict(self) -> Dict[str, torch.Tensor]:
-        fast_state_dict = self.optimizer.state_dict()
+    def state_dict(self) -> STATE:
+        fast_state_dict: STATE = self.optimizer.state_dict()
         fast_state = fast_state_dict['state']
         param_groups = fast_state_dict['param_groups']

-        slow_state: Dict[int, torch.Tensor] = {
+        slow_state: STATE = {
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
@@ -58,12 +66,12 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
             'param_groups': param_groups,
         }

-    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
-        slow_state_dict: Dict[str, torch.Tensor] = {
+    def load_state_dict(self, state_dict: STATE):
+        slow_state_dict: STATE = {
             'state': state_dict['slow_state'],
             'param_groups': state_dict['param_groups'],
         }
-        fast_state_dict: Dict[str, torch.Tensor] = {
+        fast_state_dict: STATE = {
             'state': state_dict['fast_state'],
             'param_groups': state_dict['param_groups'],
         }
@@ -72,6 +80,6 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
         self.optimizer.load_state_dict(fast_state_dict)
         self.fast_state = self.optimizer.state

-    def add_param_group(self, param_group: Dict):
+    def add_param_group(self, param_group: PARAM_GROUP):
         param_group['counter'] = 0
         self.optimizer.add_param_group(param_group)
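
As a quick orientation for the annotations above (PARAM_GROUPS, STATE, CLOSURE, LOSS), here is a minimal, hypothetical usage sketch of wrapping an inner optimizer with Lookahead; the model and data are placeholders and not part of the commit:

import torch
from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.lookahead import Lookahead

model = torch.nn.Linear(10, 1)
base = AdamP(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# Lookahead keeps fast weights in the wrapped optimizer's state (STATE) and
# syncs slow weights every k steps across its param_groups (PARAM_GROUPS).
optimizer = Lookahead(base, k=5, alpha=0.5)

loss = torch.nn.functional.mse_loss(model(torch.randn(8, 10)), torch.zeros(8, 1))
loss.backward()
optimizer.step()  # step(closure: CLOSURE = None) -> LOSS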

pytorch_optimizer/madgrad.py

Lines changed: 5 additions & 6 deletions
@@ -1,9 +1,10 @@
 import math
-from typing import Any, Callable, Dict, Optional

 import torch
 from torch.optim import Optimizer

+from pytorch_optimizer.types import CLOSURE, DEFAULT_PARAMETERS, LOSS
+

 class MADGRAD(Optimizer):
     """
@@ -26,7 +27,7 @@ def __init__(

         self.check_valid_parameters()

-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay
         )
         super().__init__(params, defaults)
@@ -49,15 +50,13 @@ def supports_memory_efficient_fp16(self) -> bool:
     def supports_flat_params(self) -> bool:
         return True

-    def step(
-        self, closure: Optional[Callable[[], float]] = None
-    ) -> Optional[float]:
+    def step(self, closure: CLOSURE = None) -> LOSS:
         """Performs a single optimization step.
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
         """
-        loss: Optional[float] = None
+        loss: LOSS = None
         if closure is not None:
             loss = closure()

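
The CLOSURE and LOSS aliases formalize the closure contract described in the docstring above. A minimal, hypothetical sketch of such a closure; the model, data, and stand-in optimizer are placeholders, not part of this commit:

import torch

model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # stand-in optimizer


def closure() -> torch.Tensor:
    # Re-evaluates the model and returns the loss, matching the
    # "closure (callable, optional)" contract in the docstring.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss


loss = optimizer.step(closure)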

pytorch_optimizer/radam.py

Lines changed: 7 additions & 5 deletions
@@ -1,9 +1,11 @@
 import math
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Dict

 import torch
 from torch.optim.optimizer import Optimizer

+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
+

 class RAdam(Optimizer):
     """
@@ -15,7 +17,7 @@ def __init__(
         self,
         params,
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         n_sma_threshold: int = 5,
@@ -42,7 +44,7 @@ def __init__(
         ):
             param['buffer'] = [[None, None, None] for _ in range(10)]

-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
@@ -67,8 +69,8 @@ def check_valid_parameters(self):
     def __setstate__(self, state: Dict):
         super().__setstate__(state)

-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None:
             loss = closure()


pytorch_optimizer/ranger.py

Lines changed: 14 additions & 8 deletions
@@ -1,9 +1,17 @@
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Dict

 import torch
 from torch.optim.optimizer import Optimizer

+from pytorch_optimizer.types import (
+    BETAS,
+    BUFFER,
+    CLOSURE,
+    DEFAULT_PARAMETERS,
+    LOSS,
+)
+

 class Ranger(Optimizer):
     """
@@ -21,7 +29,7 @@ def __init__(
         alpha: float = 0.5,
         k: int = 6,
         n_sma_threshold: int = 5,
-        betas: Tuple[float, float] = (0.95, 0.999),
+        betas: BETAS = (0.95, 0.999),
         eps: float = 1e-5,
         weight_decay: float = 0.0,
         use_gc: bool = True,
@@ -37,13 +45,11 @@ def __init__(
         self.use_gc = use_gc

         self.gc_gradient_threshold: int = 3 if gc_conv_only else 1
-        self.buffer: List[List[Optional[torch.Tensor]]] = [
-            [None, None, None] for _ in range(10)
-        ]
+        self.buffer: BUFFER = [[None, None, None] for _ in range(10)]

         self.check_valid_parameters()

-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             alpha=alpha,
             k=k,
@@ -72,8 +78,8 @@ def check_valid_parameters(self):
     def __setstate__(self, state: Dict):
         super().__setstate__(state)

-    def step(self, _: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, _: CLOSURE = None) -> LOSS:
+        loss: LOSS = None

         for group in self.param_groups:
             for p in group['params']:

pytorch_optimizer/ranger21.py

Lines changed: 6 additions & 5 deletions
@@ -9,7 +9,7 @@

 import collections
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import numpy as np
 import torch
@@ -19,6 +19,7 @@
 from pytorch_optimizer.agc import agc
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
 from pytorch_optimizer.gc import centralize_gradient
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
 from pytorch_optimizer.utils import normalize_gradient, unit_norm


@@ -43,7 +44,7 @@ def __init__(
         use_adaptive_gradient_clipping: bool = True,
         agc_clipping_value: float = 1e-2,
         agc_eps: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: BETAS = (0.9, 0.999),
         momentum_type: str = 'pnm',
         pnm_momentum_factor: float = 1.0,
         momentum: float = 0.9,
@@ -62,7 +63,7 @@ def __init__(
         warmup_pct_default: float = 0.22,
         logging_active: bool = True,
     ):
-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             momentum=momentum,
             betas=betas,
@@ -313,8 +314,8 @@ def get_state_values(group, state):
         return beta1, beta2, mean_avg, variance_avg

     @torch.no_grad()
-    def step(self, closure: Optional[Callable] = None):
-        loss = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None and isinstance(closure, collections.Callable):
             with torch.enable_grad():
                 loss = closure()

pytorch_optimizer/sgdp.py

Lines changed: 10 additions & 5 deletions
@@ -1,10 +1,12 @@
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, List, Tuple

 import torch
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer

+from pytorch_optimizer.types import CLOSURE, DEFAULT_PARAMETERS, LOSS
+

 class SGDP(Optimizer):
     def __init__(
@@ -19,7 +21,7 @@ def __init__(
         delta: float = 0.1,
         wd_ratio: float = 0.1,
     ):
-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             momentum=momentum,
             dampening=dampening,
@@ -41,7 +43,10 @@ def layer_view(x: torch.Tensor) -> torch.Tensor:

     @staticmethod
     def cosine_similarity(
-        x: torch.Tensor, y: torch.Tensor, eps: float, view_func: Callable
+        x: torch.Tensor,
+        y: torch.Tensor,
+        eps: float,
+        view_func: Callable[[torch.Tensor], torch.Tensor],
     ):
         x = view_func(x)
         y = view_func(y)
@@ -76,8 +81,8 @@ def projection(

         return perturb, wd

-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None:
             loss = closure()


pytorch_optimizer/types.py

Lines changed: 1 addition & 0 deletions
@@ -10,3 +10,4 @@
 PARAM_GROUP = Dict
 PARAM_GROUPS = List[PARAM_GROUP]
 STATE = Dict[str, Any]
+BUFFER = List[List[Optional[torch.Tensor]]]
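
For reference, a plausible reconstruction of the full pytorch_optimizer/types.py after this commit. Only the BUFFER line and the PARAM_GROUP/PARAM_GROUPS/STATE context lines are confirmed by the diff; the remaining definitions are inferred from how the aliases are used in the files above and should be read as assumptions:

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch

CLOSURE = Optional[Callable[[], float]]      # optional loss-reevaluating closure (inferred)
LOSS = Optional[float]                       # value returned by step() (inferred)
BETAS = Tuple[float, float]                  # replaces Tuple[float, float] annotations (inferred)
DEFAULT_PARAMETERS = Dict[str, Any]          # per-optimizer defaults dict (inferred)
PARAM_GROUP = Dict                           # shown as diff context
PARAM_GROUPS = List[PARAM_GROUP]             # shown as diff context
STATE = Dict[str, Any]                       # shown as diff context
BUFFER = List[List[Optional[torch.Tensor]]]  # added by this commit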
