
Commit cdfe807

Merge pull request #147 from kozistr/feature/pid-optimizer
[Feature] Implement PID optimizer
2 parents: fab0d29 + 93c5dd9

File tree: 7 files changed, 134 additions & 3 deletions


README.rst

Lines changed: 5 additions & 1 deletion

@@ -16,7 +16,7 @@ pytorch-optimizer

 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, about 40 optimizers, 6 lr schedulers are supported!
+| Currently, 43 optimizers, 6 lr schedulers are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.

@@ -171,6 +171,8 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | QHAdam | *Quasi-hyperbolic momentum and Adam for deep learning* | `github <https://github.com/facebookresearch/qhoptim>`__ | `https://arxiv.org/abs/1810.06801 <https://arxiv.org/abs/1810.06801>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| PID | *A PID Controller Approach for Stochastic Optimization of Deep Networks* | `github <https://github.com/tensorboy/PIDOptimizer>`__ | `CVPR 18 <http://www4.comp.polyu.edu.hk/~cslzhang/paper/CVPR18_PID.pdf>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------
@@ -404,6 +406,8 @@ Citations

 `QHAdam <https://github.com/facebookresearch/qhoptim#reference>`__

+`PID <https://github.com/tensorboy/PIDOptimizer#citation>`__
+
 Citation
 --------
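For context, the update rule implemented by the new optimizer (see pytorch_optimizer/optimizer/pid.py below) can be sketched as follows; this is a summary of its step() method with learning rate \eta, momentum \mu, dampening \tau, gradient g_t, and the integral/derivative gains I and D from the constructor, not text taken from the paper or the commit:

    v_t = \mu \, v_{t-1} + (1 - \tau) \, g_t
    d_t = \mu \, d_{t-1} + (1 - \mu) \, (g_t - g_{t-1})
    \theta_{t+1} = \theta_t - \eta \, (g_t + I \, v_t + D \, d_t)

When momentum is 0 the buffers are skipped and the update reduces to plain SGD, \theta_{t+1} = \theta_t - \eta \, g_t; weight decay is applied beforehand, either as an L2 term added to g_t or decoupled as in AdamW.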

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.7.0"
+version = "2.8.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,7 @@
 from pytorch_optimizer.optimizer.nero import Nero
 from pytorch_optimizer.optimizer.novograd import NovoGrad
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
+from pytorch_optimizer.optimizer.pid import PID
 from pytorch_optimizer.optimizer.pnm import PNM
 from pytorch_optimizer.optimizer.qhadam import QHAdam
 from pytorch_optimizer.optimizer.qhm import QHM
@@ -89,6 +90,7 @@
 OPTIMIZER_LIST: List[OPTIMIZER] = [
     AdaBelief,
     AdaBound,
+    PID,
     AdamP,
     Adai,
     Adan,
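With this change PID is re-exported from the package top level and registered in OPTIMIZER_LIST, so it can be pulled in either way; a small sketch (load_optimizer is the package's existing name-based loader and is not part of this diff, so treat that call as an assumption):

from pytorch_optimizer import PID, load_optimizer

opt_cls = PID                    # direct class import, enabled by this commit
opt_cls = load_optimizer('pid')  # name-based lookup via the existing loader (assumed API)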

pytorch_optimizer/optimizer/pid.py

Lines changed: 121 additions & 0 deletions (new file)

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class PID(Optimizer, BaseOptimizer):
    r"""A PID Controller Approach for Stochastic Optimization of Deep Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param dampening: float. dampening for momentum.
    :param derivative: float. D part of the PID.
    :param integral: float. I part of the PID.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        derivative: float = 10.0,
        integral: float = 5.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
    ):
        self.lr = lr
        self.momentum = momentum
        self.dampening = dampening
        self.derivative = derivative
        self.integral = integral
        self.weight_decay = weight_decay

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'derivative': derivative,
            'integral': integral,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_momentum(self.momentum)
        self.validate_weight_decay(self.weight_decay)

    def __str__(self) -> str:
        return 'PID'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['grad_buffer'] = torch.zeros_like(p)
                    state['i_buffer'] = torch.zeros_like(p)
                    state['d_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0 and group['momentum'] > 0.0:
                    state['grad_buffer'] = torch.zeros_like(p)
                    state['i_buffer'] = torch.zeros_like(p)
                    state['d_buffer'] = torch.zeros_like(p)

                if group['weight_decouple']:
                    p.mul_(1.0 - group['weight_decay'] * group['lr'])
                elif group['weight_decay'] > 0.0:
                    grad.add_(p, alpha=group['weight_decay'])

                if group['momentum'] > 0.0:
                    i_buf = state['i_buffer']
                    i_buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])

                    g_buf, d_buf = state['grad_buffer'], state['d_buffer']
                    d_buf.mul_(group['momentum'])

                    if group['step'] > 1:
                        d_buf.add_(grad - g_buf, alpha=1.0 - group['momentum'])
                        g_buf.copy_(grad)

                    grad.add_(i_buf, alpha=group['integral']).add_(d_buf, alpha=group['derivative'])

                p.add_(grad, alpha=-group['lr'])

        return loss
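A minimal usage sketch of the new optimizer, driven like any other torch.optim optimizer; the toy model, data, and hyper-parameter values below are illustrative only and are not part of this commit:

import torch
from torch import nn

from pytorch_optimizer import PID

model = nn.Linear(4, 1)                        # toy model, illustrative only
x, y = torch.randn(32, 4), torch.randn(32, 1)  # toy data

# momentum > 0 enables the integral/derivative buffers in step()
optimizer = PID(model.parameters(), lr=1e-3, momentum=0.9, derivative=10.0, integral=5.0)

for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()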

tests/constants.py

Lines changed: 3 additions & 0 deletions

@@ -6,6 +6,7 @@
     MADGRAD,
     MSVAG,
     OPTIMIZERS,
+    PID,
     PNM,
     QHM,
     SGDP,
@@ -346,6 +347,8 @@
     (QHAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
     (QHM, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
     (QHM, {'lr': 1e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
+    (PID, {'lr': 1e0, 'momentum': 0.9, 'dampening': 1.0, 'weight_decay': 1e-3}, 5),
+    (PID, {'lr': 1e0, 'momentum': 0.9, 'dampening': 1.0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_general_optimizer_parameters.py

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ def test_epsilon(optimizer_name):
         'msvag',
         'aggmo',
         'qhm',
+        'pid',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion

@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 42
+    assert len(get_supported_optimizers()) == 43
