Commit d657060

Merge pull request #54 from kozistr/refactor/optimizers
[Refactor] Optimizers
2 parents (6db0d49 + 619c169) · commit d657060

File tree

16 files changed: +52, -62 lines changed

lint.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def get_configuration() -> Namespace:
     parser.add_argument(
         '-t',
         '--threshold',
-        default=9.9,
+        default=9.95,
         type=float,
     )
 
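
For context, here is a self-contained sketch of the parser this hunk lives in; only `get_configuration()` and the `-t`/`--threshold` option come from the diff, the rest of the setup is an assumption.

```python
# Sketch of the parser around the changed default; only get_configuration()
# and the --threshold option are taken from the diff, the rest is assumed.
from argparse import ArgumentParser, Namespace


def get_configuration() -> Namespace:
    parser = ArgumentParser()  # hypothetical: the real parser may define more options
    parser.add_argument(
        '-t',
        '--threshold',
        default=9.95,  # raised from 9.9; presumably the minimum acceptable lint score
        type=float,
    )
    return parser.parse_args()


if __name__ == '__main__':
    print(get_configuration().threshold)
```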

pytorch_optimizer/adabelief.py

Lines changed: 4 additions & 10 deletions
@@ -82,12 +82,6 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
 
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsgrad', False)
-            group.setdefault('adamd_debias_term', False)
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -152,11 +146,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_residual = grad - exp_avg
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
 
+                exp_avg_var = exp_avg_var.add_(group['eps'])
                 if group['amsgrad']:
-                    max_exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var.add_(group['eps']))
-                    de_nom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-                else:
-                    de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var)
+
+                de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                 if not self.rectify:
                     step_size = group['lr']
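
The second hunk folds the `amsgrad`/non-`amsgrad` branches into one denominator: `eps` is added to the variance estimate up front, the running maximum is taken only when `amsgrad` is enabled, and a single `de_nom` line serves both paths. A minimal standalone sketch of that control flow (plain tensors in place of optimizer state; the helper name is illustrative, not part of the library):

```python
import math

import torch


def adabelief_de_nom(
    exp_avg_var: torch.Tensor,      # second-moment (variance) estimate
    max_exp_avg_var: torch.Tensor,  # running maximum used when amsgrad is on
    bias_correction2: float,
    eps: float,
    amsgrad: bool,
) -> torch.Tensor:
    # eps is added once up front (in place, mirroring the diff)
    exp_avg_var = exp_avg_var.add_(eps)
    if amsgrad:
        # element-wise maximum replaces the separate amsgrad branch
        exp_avg_var = torch.max(max_exp_avg_var, exp_avg_var)
    # single shared denominator, as in the refactored step()
    return (exp_avg_var.add_(eps).sqrt() / math.sqrt(bias_correction2)).add_(eps)
```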

pytorch_optimizer/adabound.py

Lines changed: 3 additions & 9 deletions
@@ -80,11 +80,6 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
 
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsbound', False)
-
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
@@ -140,10 +135,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 if group['amsbound']:
-                    max_exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
-                    de_nom = max_exp_avg_sq.sqrt().add_(group['eps'])
-                else:
-                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
+                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+
+                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
                 bias_correction2 = 1.0 - beta2 ** state['step']
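
The same flattening is applied to AdaBound: the `amsbound` case only swaps in the running maximum of the second moment, and one shared `de_nom` line follows. A short sketch under the same assumptions (standalone tensors, illustrative helper name):

```python
import torch


def adabound_de_nom(
    exp_avg_sq: torch.Tensor,      # second-moment estimate
    max_exp_avg_sq: torch.Tensor,  # running maximum used when amsbound is on
    eps: float,
    amsbound: bool,
) -> torch.Tensor:
    if amsbound:
        exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
    # shared denominator for both paths, as in the refactored step()
    return exp_avg_sq.sqrt().add_(eps)
```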

pytorch_optimizer/diffgrad.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base_optimizer import BaseOptimizer
-from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
 class DiffGrad(Optimizer, BaseOptimizer):

pytorch_optimizer/diffrgrad.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch.optim.optimizer import Optimizer
 
 from pytorch_optimizer.base_optimizer import BaseOptimizer
-from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
 class DiffRGrad(Optimizer, BaseOptimizer):

pytorch_optimizer/lars.py

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad = grad.add(p, alpha=g['weight_decay'])
                 param_norm = torch.norm(p)
                 update_norm = torch.norm(grad)
-                one = torch.ones_like(param_norm)
+                one = torch.ones_like(param_norm, device=param_norm.device)
 
                 q = torch.where(
                     param_norm > 0.0,
@@ -100,7 +100,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 param_state = self.state[p]
                 if 'mu' not in param_state:
-                    param_state['mu'] = torch.zeros_like(p)
+                    param_state['mu'] = torch.zeros_like(p, device=p.device)
 
                 mu = param_state['mu']
                 mu.mul_(g['momentum']).add_(grad)
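
Both changed lines pass `device=` explicitly when allocating state. `torch.ones_like` and `torch.zeros_like` already inherit the input's device by default, so the change mainly makes the intent explicit; a quick check:

```python
import torch

# Stand-in for a parameter; the same holds for a CUDA tensor.
p = torch.randn(3, 4)
param_norm = torch.norm(p)

# Both forms produce a tensor on the same device as their input;
# the explicit `device=` added in the diff just spells that out.
one_implicit = torch.ones_like(param_norm)
one_explicit = torch.ones_like(param_norm, device=param_norm.device)
mu = torch.zeros_like(p, device=p.device)

assert one_implicit.device == one_explicit.device == mu.device
```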

pytorch_optimizer/pcgrad.py

Lines changed: 14 additions & 12 deletions
@@ -39,15 +39,15 @@ def validate_parameters(self):
 
     @torch.no_grad()
     def reset(self):
-        pass
+        self.zero_grad()
 
     def zero_grad(self):
         return self.optimizer.zero_grad(set_to_none=True)
 
     def step(self):
         return self.optimizer.step()
 
-    def set_grad(self, grads):
+    def set_grad(self, grads: List[torch.Tensor]):
         idx: int = 0
         for group in self.optimizer.param_groups:
             for p in group['params']:
@@ -74,7 +74,7 @@ def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
     def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
         """pack the gradient of the parameters of the network for each objective
         :param objectives: Iterable[nn.Module]. a list of objectives
-        :return:
+        :return: torch.Tensor. packed gradients
         """
         grads, shapes, has_grads = [], [], []
         for objective in objectives:
@@ -89,27 +89,29 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
 
         return grads, shapes, has_grads
 
-    def project_conflicting(self, grads, has_grads) -> torch.Tensor:
+    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
         """project conflicting
         :param grads: a list of the gradient of the parameters
         :param has_grads: a list of mask represent whether the parameter has gradient
-        :return:
+        :return: torch.Tensor. merged gradients
         """
-        shared = torch.stack(has_grads).prod(0).bool()
+        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()
 
-        pc_grad = deepcopy(grads)
+        pc_grad: List[torch.Tensor] = deepcopy(grads)
         for g_i in pc_grad:
             random.shuffle(grads)
             for g_j in grads:
-                g_i_g_j = torch.dot(g_i, g_j)
+                g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
                 if g_i_g_j < 0:
                     g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
 
-        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
+        merged_grad: torch.Tensor = torch.zeros_like(grads[0], device=grads[0].device)
+
+        shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
         if self.reduction == 'mean':
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
         else:
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
+            merged_grad[shared] = shared_pc_gradients.sum(dim=0)
 
         merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
 
@@ -121,7 +123,7 @@ def pc_backward(self, objectives: Iterable[nn.Module]):
         :return:
         """
         grads, shapes, has_grads = self.pack_grad(objectives)
+
         pc_grad = self.project_conflicting(grads, has_grads)
         pc_grad = un_flatten_grad(pc_grad, shapes[0])
-
         self.set_grad(pc_grad)
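
Beyond the added type annotations, `reset()` now clears gradients instead of being a no-op, and the stack of shared gradients is hoisted into `shared_pc_gradients` so it is built once for either reduction. As a refresher on what `project_conflicting` computes, here is a minimal self-contained sketch of the PCGrad projection on flattened gradients (simplified: no `has_grads` masking, and not the library class itself):

```python
import random
from copy import deepcopy
from typing import List

import torch


def project_conflicting(grads: List[torch.Tensor], reduction: str = 'mean') -> torch.Tensor:
    """Project each task gradient away from the ones it conflicts with, then reduce.

    Simplified sketch: assumes every parameter has a gradient for every task.
    """
    pc_grad: List[torch.Tensor] = deepcopy(grads)
    for g_i in pc_grad:
        random.shuffle(grads)
        for g_j in grads:
            g_i_g_j = torch.dot(g_i, g_j)
            if g_i_g_j < 0:
                # remove the conflicting component of g_i along g_j
                g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

    stacked = torch.stack(pc_grad)
    return stacked.mean(dim=0) if reduction == 'mean' else stacked.sum(dim=0)


# Two deliberately conflicting task gradients:
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])
print(project_conflicting([g1, g2]))
```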

pytorch_optimizer/radam.py

Lines changed: 3 additions & 5 deletions
@@ -1,5 +1,4 @@
 import math
-from typing import Dict
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -153,14 +152,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         step_size = -1
                     buffered[2] = step_size
 
+                if group['weight_decay'] != 0 and (n_sma >= self.n_sma_threshold or step_size > 0):
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
+
                 if n_sma >= self.n_sma_threshold:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                     p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
                     p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
                 if p.dtype in (torch.float16, torch.bfloat16):
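
The weight-decay step, previously duplicated inside both update branches, is hoisted into a single guard that fires whenever an update will actually be applied (`n_sma >= self.n_sma_threshold` or `step_size > 0`). A condensed sketch of the restructured flow, using stand-in tensors and a free function instead of the optimizer's `step()` (function name and signature are illustrative):

```python
import torch


def apply_radam_update(
    p_fp32: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    n_sma: float,
    n_sma_threshold: float,
    step_size: float,
    lr: float,
    weight_decay: float,
    eps: float,
) -> None:
    # weight decay applied directly to the parameters, once for either branch
    if weight_decay != 0 and (n_sma >= n_sma_threshold or step_size > 0):
        p_fp32.add_(p_fp32, alpha=-weight_decay * lr)

    if n_sma >= n_sma_threshold:
        # variance is tractable: adaptive (Adam-like) update
        de_nom = exp_avg_sq.sqrt().add_(eps)
        p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * lr)
    elif step_size > 0:
        # fall back to an SGD-with-momentum style update
        p_fp32.add_(exp_avg, alpha=-step_size * lr)
```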

pytorch_optimizer/ralamb.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def get_gradient_norm(self) -> float:
                 if p.grad is None:
                     continue
 
-                norm_sq += torch.linalg.norm(p.grad).item() ** 2
+                norm_sq += torch.linalg.norm(p.grad).cpu().numpy() ** 2
 
         norm = math.sqrt(norm_sq)
 
@@ -147,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 state['step'] += 1
-                buffered = group['buffer'][int(state['step'] % 10)]
+                buffered = group['buffer'][state['step'] % 10]
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
 
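
`get_gradient_norm` now accumulates `torch.linalg.norm(p.grad).cpu().numpy() ** 2` rather than `.item() ** 2`; either way the per-parameter norm is moved to the host, squared, summed, and passed to `math.sqrt`. The buffer index also drops a redundant `int()` cast, presumably because `state['step']` is already a Python `int`. A small sketch of the accumulation pattern (helper and model are illustrative, not the library method):

```python
import math

import torch


def global_grad_norm(parameters) -> float:
    """Accumulate per-parameter gradient norms into one global norm.

    Mirrors the pattern in the diff: square each parameter's gradient norm
    on the host, sum, then take the square root.
    """
    norm_sq = 0.0
    for p in parameters:
        if p.grad is None:
            continue
        norm_sq += torch.linalg.norm(p.grad).cpu().numpy() ** 2  # equivalently: .item() ** 2
    return math.sqrt(norm_sq)


# usage with a throwaway linear layer
layer = torch.nn.Linear(4, 2)
layer(torch.randn(8, 4)).sum().backward()
print(global_grad_norm(layer.parameters()))
```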

pytorch_optimizer/ranger.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import math
-from typing import Dict
 
 import torch
 from torch.optim.optimizer import Optimizer
