
Commit d00136f

Merge pull request #261 from kozistr/refactor/code: [Refactor] code

2 parents 474510f + 59e3ec3, commit d00136f


71 files changed: +204, -240 lines changed

Makefile

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ check:
 	ruff check pytorch_optimizer examples tests hubconf.py
 
 requirements:
-	python -m poetry export -f requirements.txt --output requirements.txt --without-hashes
-	python -m poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev
+	poetry export -f requirements.txt --output requirements.txt --without-hashes
+	poetry export -f requirements.txt --output requirements-dev.txt --without-hashes --with dev
 
 docs:
 	mkdocs serve

docs/changelogs/v3.1.0.md

Lines changed: 3 additions & 0 deletions
@@ -11,13 +11,16 @@
 * `bnb_paged_adam8bit`, `bnb_paged_adamw8bit`, `bnb_*_*32bit`.
 * Improve `power_iteration()` speed up to 40%. (#259)
 * Improve `reg_noise()` (E-MCMC) speed up to 120%. (#260)
+* Support `disable_lr_scheduler` parameter for `Ranger21` optimizer to disable built-in learning rate scheduler. (#261)
 
 ### Refactor
 
 * Refactor `AdamMini` optimizer. (#258)
 * Deprecate optional dependency, `bitsandbytes`. (#258)
 * Move `get_rms`, `approximate_sq_grad` functions to `BaseOptimizer` for reusability. (#258)
 * Refactor `shampoo_utils.py`. (#259)
+* Add `debias`, `debias_adam` methods in `BaseOptimizer`. (#261)
+* Refactor to use `BaseOptimizer` only, not inherit multiple classes. (#261)
 
 ### Bug
 
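
The `disable_lr_scheduler` entry above is only described in this changelog excerpt; a minimal usage sketch follows, assuming the flag is a boolean constructor argument of `Ranger21` (the exact signature is not part of this diff):

    import torch
    from pytorch_optimizer import Ranger21

    model = torch.nn.Linear(4, 1)

    # With the (assumed) flag set, Ranger21 would skip its built-in warm-up/warm-down
    # schedule and leave the learning rate entirely under external control.
    optimizer = Ranger21(model.parameters(), num_iterations=1_000, lr=1e-3, disable_lr_scheduler=True)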

pytorch_optimizer/base/optimizer.py

Lines changed: 38 additions & 10 deletions
@@ -3,17 +3,21 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from torch.optim import Optimizer
 
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G, PARAMETERS, STATE
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE
 
 
-class BaseOptimizer(ABC):
-    r"""Base optimizer class."""
+class BaseOptimizer(ABC, Optimizer):
+    r"""Base optimizer class. Provides common functionalities for the optimizers."""
+
+    def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
+        super().__init__(params, defaults)
 
     @staticmethod
     @torch.no_grad()
-    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
+    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
         r"""Set hessian to state from external source. Generally useful when using functorch as a base.
 
         Example:
@@ -45,7 +49,7 @@ def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tens
             i += 1
 
     @staticmethod
-    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
+    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None:
         r"""Zero-out hessian.
 
         :param param_groups: PARAMETERS. parameter groups.
@@ -68,7 +72,7 @@ def compute_hutchinson_hessian(
         num_samples: int = 1,
         alpha: float = 1.0,
         distribution: HUTCHINSON_G = 'gaussian',
-    ):
+    ) -> None:
         r"""Hutchinson's approximate hessian, added to the state under key `hessian`.
 
         :param param_groups: PARAMETERS. parameter groups.
@@ -110,7 +114,7 @@ def apply_weight_decay(
         weight_decouple: bool,
         fixed_decay: bool,
         ratio: Optional[float] = None,
-    ):
+    ) -> None:
         r"""Apply weight decay.
 
         :param p: torch.Tensor. parameter.
@@ -145,6 +149,27 @@ def apply_ams_bound(
 
         return de_nom.sqrt_().add_(eps)
 
+    @staticmethod
+    def debias(beta: float, step: int) -> float:
+        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        return 1.0 - math.pow(beta, step)  # fmt: skip
+
+    @staticmethod
+    def debias_beta(beta: float, step: int) -> float:
+        r"""Apply the Adam-style debias correction into beta.
+
+        Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`
+
+        :param beta: float. beta.
+        :param step: int. number of step.
+        """
+        beta_n: float = math.pow(beta, step)
+        return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip
+
     @staticmethod
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
         r"""Apply AdamD variant.
@@ -205,14 +230,14 @@ def get_adanorm_gradient(
         :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
         :param r: float. Optional[float]. momentum (ratio).
         """
-        if not adanorm:
+        if not adanorm or exp_grad_norm is None:
             return grad
 
         grad_norm = torch.linalg.norm(grad)
 
         exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)
 
-        return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
+        return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
 
     @staticmethod
     def get_rms(x: torch.Tensor) -> float:
@@ -299,5 +324,8 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
         self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')
 
     @abstractmethod
-    def reset(self):  # pragma: no cover
+    def reset(self) -> None:  # pragma: no cover
+        raise NotImplementedError
+
+    def step(self, closure: CLOSURE = None) -> LOSS:  # pragma: no cover
         raise NotImplementedError
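
The new `debias` and `debias_beta` helpers above encode the standard Adam bias-correction algebra; the following standalone sketch (not repository code) checks that the closed form in `debias_beta` matches the expanded expression quoted in its docstring:

    import math

    def debias(beta: float, step: int) -> float:
        # Adam-style bias-correction denominator: 1 - beta**step
        return 1.0 - math.pow(beta, step)

    def debias_beta(beta: float, step: int) -> float:
        # (beta**step - beta) / (beta**step - 1) == beta * (1 - beta**(step - 1)) / (1 - beta**step)
        beta_n = math.pow(beta, step)
        return (beta_n - beta) / (beta_n - 1.0)

    beta, step = 0.9, 10
    expanded = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)
    assert math.isclose(debias_beta(beta, step), expanded)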

pytorch_optimizer/optimizer/a2grad.py

Lines changed: 1 addition & 2 deletions
@@ -2,14 +2,13 @@
 from typing import Optional
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class A2Grad(Optimizer, BaseOptimizer):
+class A2Grad(BaseOptimizer):
     r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
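
The `A2Grad` change above is the pattern repeated across the remaining files: each optimizer now subclasses `BaseOptimizer` alone, which itself extends `torch.optim.Optimizer` and forwards `params`/`defaults`. A minimal sketch of that pattern, using a hypothetical `MySGD` that is not part of the library:

    import torch
    from pytorch_optimizer.base.optimizer import BaseOptimizer
    from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS

    class MySGD(BaseOptimizer):
        def __init__(self, params: PARAMETERS, lr: float = 1e-3) -> None:
            defaults: DEFAULTS = {'lr': lr}
            super().__init__(params, defaults)  # BaseOptimizer.__init__ -> torch.optim.Optimizer.__init__

        def reset(self) -> None:
            pass

        @torch.no_grad()
        def step(self, closure: CLOSURE = None) -> LOSS:
            loss: LOSS = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()

            # plain SGD update, just to exercise the inherited param_groups machinery
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group['lr'])
            return loss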

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 3 additions & 4 deletions
@@ -1,14 +1,13 @@
 import math
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdaBelief(Optimizer, BaseOptimizer):
+class AdaBelief(BaseOptimizer):
     r"""Adapting Step-sizes by the Belief in Observed Gradients.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -101,8 +100,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction1: float = 1.0 - beta1 ** group['step']
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
 
             step_size, n_sma = self.get_rectify_step_size(
                 is_rectify=group['rectify'],

pytorch_optimizer/optimizer/adabound.py

Lines changed: 3 additions & 4 deletions
@@ -2,14 +2,13 @@
 from typing import List
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdaBound(Optimizer, BaseOptimizer):
+class AdaBound(BaseOptimizer):
     r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -90,8 +89,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction1: float = 1.0 - beta1 ** group['step']
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
 
             final_lr: float = group['final_lr'] * group['lr'] / base_lr
             lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1))

pytorch_optimizer/optimizer/adadelta.py

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdaDelta(Optimizer, BaseOptimizer):
+class AdaDelta(BaseOptimizer):
     r"""An Adaptive Learning Rate Method.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 1 addition & 2 deletions
@@ -2,14 +2,13 @@
 from typing import Optional, Tuple
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
-class AdaFactor(Optimizer, BaseOptimizer):
+class AdaFactor(BaseOptimizer):
     r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 3 additions & 4 deletions
@@ -1,14 +1,13 @@
 from typing import List, Optional
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS
 
 
-class AdaHessian(Optimizer, BaseOptimizer):
+class AdaHessian(BaseOptimizer):
     r"""An Adaptive Second Order Optimizer for Machine Learning.
 
     Requires `loss.backward(create_graph=True)` in order to calculate hessians.
@@ -104,8 +103,8 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =
 
             beta1, beta2 = group['betas']
 
-            bias_correction1: float = 1.0 - beta1 ** group['step']
-            bias_correction2: float = 1.0 - beta2 ** group['step']
+            bias_correction1: float = self.debias(beta1, group['step'])
+            bias_correction2: float = self.debias(beta2, group['step'])
 
             for p in group['params']:
                 if p.grad is None:
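
The docstring context above notes that `AdaHessian` requires `loss.backward(create_graph=True)`; a minimal training-step sketch under that assumption (constructor arguments beyond `params` and `lr` are left at their defaults):

    import torch
    from pytorch_optimizer import AdaHessian

    model = torch.nn.Linear(4, 1)
    optimizer = AdaHessian(model.parameters(), lr=1e-1)

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward(create_graph=True)  # keep the graph so Hutchinson Hessian estimates can backprop again
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)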

pytorch_optimizer/optimizer/adai.py

Lines changed: 3 additions & 4 deletions
@@ -1,15 +1,14 @@
 import math
 
 import torch
-from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.gc import centralize_gradient
 
 
-class Adai(Optimizer, BaseOptimizer):
+class Adai(BaseOptimizer):
     r"""Disentangling the Effects of Adaptive Learning Rate and Momentum.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -105,7 +104,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if self.use_gc:
                     centralize_gradient(grad, gc_conv_only=False)
 
-                bias_correction2: float = 1.0 - beta2 ** state['step']
+                bias_correction2: float = self.debias(beta2, state['step'])
 
                 if not group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                     self.apply_weight_decay(
@@ -148,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         fixed_decay=group['fixed_decay'],
                     )
 
-                bias_correction2: float = 1.0 - beta2 ** state['step']
+                bias_correction2: float = self.debias(beta2, state['step'])
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 