
Commit 6db0d49

Merge pull request #53 from kozistr/refactor/cov
[Refactor] Coverage
2 parents 358ff43 + a23d40b · commit 6db0d49

31 files changed: +555 additions, -297 deletions

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-@kozistr
+* @kozistr

README.rst

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ or you can use optimizer loader, simply passing a name of the optimizer.

     ...
     model = YourModel()
-    opt = load_optimizers(optimizer='adamp', use_fp16=True)
+    opt = load_optimizers(optimizer='adamp')
     optimizer = opt(model.parameters())
     ...
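
For context, here is a minimal end-to-end sketch of the loader workflow shown in that snippet. Only load_optimizers(optimizer='adamp') and the returned optimizer class come from the README; the import path, the toy model, and the training step are illustrative assumptions.

    import torch
    from torch import nn

    from pytorch_optimizer import load_optimizers  # import path assumed

    model = nn.Linear(10, 1)  # stand-in for YourModel()

    # the loader resolves an optimizer class by name
    opt_class = load_optimizers(optimizer='adamp')
    optimizer = opt_class(model.parameters(), lr=1e-3)

    # one illustrative training step
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()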

pytorch_optimizer/adabound.py

Lines changed: 15 additions & 3 deletions
@@ -84,7 +84,18 @@ def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsbound', False)
-            group.setdefault('adamd_debias_term', False)
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                if group['amsbound']:
+                    state['max_exp_avg_sq'] = torch.zeros_like(p)

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -127,14 +138,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
                 if group['amsbound']:
                     max_exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
                     de_nom = max_exp_avg_sq.sqrt().add_(group['eps'])
                 else:
                     de_nom = exp_avg_sq.sqrt().add_(group['eps'])

-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                bias_correction1 = 1.0 - beta1 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** state['step']

                 step_size = group['lr'] * math.sqrt(bias_correction2)
                 if not group['adamd_debias_term']:
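
The new reset() hook re-initializes the per-parameter state in place, which is handy for tests and warm restarts. A short hedged usage sketch (the top-level AdaBound import and the constructor arguments are assumed):

    import torch
    from torch import nn

    from pytorch_optimizer import AdaBound  # top-level export assumed

    model = nn.Linear(4, 2)
    optimizer = AdaBound(model.parameters(), lr=1e-3)

    # take one step so that per-parameter state (step, exp_avg, ...) exists
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()

    # wipe step counters and moment buffers without rebuilding the optimizer
    optimizer.reset()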

pytorch_optimizer/adahessian.py

Lines changed: 15 additions & 6 deletions
@@ -138,6 +138,16 @@ def set_hessian(self):
            # approximate the expected values of z * (H@z)
            p.hess += h_z * z / self.num_samples

+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_hessian_diag_sq'] = torch.zeros_like(p)
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
@@ -171,14 +181,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 beta1, beta2 = group['betas']

                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
+                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1.0 - beta2)

-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                bias_correction1 = 1.0 - beta1 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** state['step']

-                hessian_power = group['hessian_power']
-                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(hessian_power / 2.0).add_(group['eps'])
+                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(group['hessian_power'] / 2.0).add_(group['eps'])

                 step_size = group['lr']
                 if not group['adamd_debias_term']:
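
The "z * (H@z)" comment refers to Hutchinson's estimator: with Rademacher probes z, the expectation of z * (Hz) is the Hessian diagonal. A standalone hedged sketch of that idea, independent of the optimizer's internals (function and variable names are illustrative):

    import torch

    def hutchinson_hessian_diag(loss, params, num_samples: int = 1):
        # first-order grads with the graph kept, so Hessian-vector products are possible
        grads = torch.autograd.grad(loss, params, create_graph=True)

        estimates = [torch.zeros_like(p) for p in params]
        for _ in range(num_samples):
            # Rademacher probe vectors z with entries in {-1, +1}
            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
            # E[z * (H @ z)] equals the Hessian diagonal
            for est, h_z, z in zip(estimates, h_zs, zs):
                est.add_(h_z * z / num_samples)
        return estimates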

pytorch_optimizer/adamp.py

Lines changed: 11 additions & 43 deletions
@@ -1,13 +1,12 @@
 import math
-from typing import Callable, List, Tuple

 import torch
-import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer

 from pytorch_optimizer.base_optimizer import BaseOptimizer
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.utils import projection


 class AdamP(Optimizer, BaseOptimizer):
@@ -80,46 +79,15 @@ def validate_parameters(self):
         self.validate_weight_decay_ratio(self.wd_ratio)
         self.validate_epsilon(self.eps)

-    @staticmethod
-    def channel_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(x.size()[0], -1)
-
-    @staticmethod
-    def layer_view(x: torch.Tensor) -> torch.Tensor:
-        return x.view(1, -1)
-
-    @staticmethod
-    def cosine_similarity(
-        x: torch.Tensor,
-        y: torch.Tensor,
-        eps: float,
-        view_func: Callable[[torch.Tensor], torch.Tensor],
-    ) -> torch.Tensor:
-        x = view_func(x)
-        y = view_func(y)
-        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
-
-    def projection(
-        self,
-        p,
-        grad,
-        perturb: torch.Tensor,
-        delta: float,
-        wd_ratio: float,
-        eps: float,
-    ) -> Tuple[torch.Tensor, float]:
-        wd: float = 1.0
-        expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
-        for view_func in (self.channel_view, self.layer_view):
-            cosine_sim = self.cosine_similarity(grad, p, eps, view_func)
-
-            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
-                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
-                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
-                wd = wd_ratio
-                return perturb, wd
-
-        return perturb, wd
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -166,7 +134,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 wd_ratio: float = 1
                 if len(p.shape) > 1:
-                    perturb, wd_ratio = self.projection(
+                    perturb, wd_ratio = projection(
                         p,
                         grad,
                         perturb,
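
The projection helper is now a free function in pytorch_optimizer.utils, called with the same arguments as the removed method (minus self). A hedged sketch of calling it directly; the tensor shapes and the delta/wd_ratio/eps values are illustrative, and the standalone signature is assumed to mirror the removed one:

    import torch

    from pytorch_optimizer.utils import projection  # moved here by this commit

    weight = torch.randn(8, 16)        # a 2D parameter, e.g. a linear layer weight
    grad = torch.randn_like(weight)
    perturb = grad.clone()             # stand-in for the Adam update direction

    # positional arguments follow the removed method: p, grad, perturb, delta, wd_ratio, eps
    perturb, wd_ratio = projection(weight, grad, perturb, 0.1, 0.1, 1e-8)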

pytorch_optimizer/base_optimizer.py

Lines changed: 26 additions & 2 deletions
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod

+import torch
+
 from pytorch_optimizer.types import BETAS


@@ -48,8 +50,8 @@ def validate_momentum(momentum: float):

     @staticmethod
     def validate_lookahead_k(k: int):
-        if k < 0:
-            raise ValueError(f'[-] k {k} must be non-negative')
+        if k < 1:
+            raise ValueError(f'[-] k {k} must be positive')

     @staticmethod
     def validate_rho(rho: float):
@@ -61,6 +63,28 @@ def validate_epsilon(epsilon: float):
         if epsilon < 0.0:
             raise ValueError(f'[-] epsilon {epsilon} must be non-negative')

+    @staticmethod
+    def validate_alpha(alpha: float):
+        if not 0.0 <= alpha < 1.0:
+            raise ValueError(f'[-] alpha {alpha} must be in the range [0, 1)')
+
+    @staticmethod
+    def validate_pullback_momentum(pullback_momentum: str):
+        if pullback_momentum not in ('none', 'reset', 'pullback'):
+            raise ValueError(
+                f'[-] pullback_momentum {pullback_momentum} must be one of (\'none\' or \'reset\' or \'pullback\')'
+            )
+
+    @staticmethod
+    def validate_reduction(reduction: str):
+        if reduction not in ('mean', 'sum'):
+            raise ValueError(f'[-] reduction {reduction} must be one of (\'mean\' or \'sum\')')
+
     @abstractmethod
     def validate_parameters(self):
         raise NotImplementedError
+
+    @abstractmethod
+    @torch.no_grad()
+    def reset(self):
+        raise NotImplementedError
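
With reset() added as an abstract method, every concrete optimizer must now implement it alongside validate_parameters(). A minimal hedged sketch of a conforming subclass (the class itself is illustrative and not part of the library):

    import torch
    from torch.optim.optimizer import Optimizer

    from pytorch_optimizer.base_optimizer import BaseOptimizer


    class PlainSGD(Optimizer, BaseOptimizer):  # illustrative example, not a library class
        def __init__(self, params, lr: float = 1e-3, eps: float = 1e-8):
            self.eps = eps
            self.validate_parameters()
            super().__init__(params, {'lr': lr, 'eps': eps})

        def validate_parameters(self):
            self.validate_epsilon(self.eps)

        @torch.no_grad()
        def reset(self):
            # the new contract: clear per-parameter state in place
            for group in self.param_groups:
                for p in group['params']:
                    self.state[p] = {}

        @torch.no_grad()
        def step(self, closure=None):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group['lr'])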

pytorch_optimizer/diffgrad.py

Lines changed: 10 additions & 2 deletions
@@ -58,8 +58,16 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)

-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['previous_grad'] = torch.zeros_like(p)

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:

pytorch_optimizer/diffrgrad.py

Lines changed: 14 additions & 7 deletions
@@ -71,8 +71,16 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)

-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['previous_grad'] = torch.zeros_like(p)

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -123,7 +131,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 dfc = 1.0 / (1.0 + torch.exp(-diff))
                 state['previous_grad'] = grad.clone()

-                buffered = group['buffer'][int(state['step'] % 10)]
+                buffered = group['buffer'][state['step'] % 10]
                 if state['step'] == buffered[0]:
                     n_sma, step_size = buffered[1], buffered[2]
                 else:
@@ -144,10 +152,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                            / (n_sma_max - 2)
                        )

-                        if group['adamd_debias_term']:
-                            step_size = rt
-                        else:
-                            step_size = rt / bias_correction1
+                        step_size = rt
+                        if not group['adamd_debias_term']:
+                            step_size /= bias_correction1
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / bias_correction1
                    else:
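
The debias branch was only restructured; both forms yield the same step size. A tiny hedged check with made-up numbers:

    rt, bias_correction1 = 0.9, 0.5

    for adamd_debias_term in (True, False):
        # old form: explicit if/else
        old = rt if adamd_debias_term else rt / bias_correction1

        # new form: assign, then conditionally divide
        new = rt
        if not adamd_debias_term:
            new /= bias_correction1

        assert old == new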

pytorch_optimizer/fp16.py

Lines changed: 23 additions & 17 deletions
@@ -1,9 +1,10 @@
 from typing import Dict, List, Optional, Union

 import torch
+from torch import nn
 from torch.optim import Optimizer

-from pytorch_optimizer.types import CLOSURE
+from pytorch_optimizer.types import CLOSURE, PARAMETERS
 from pytorch_optimizer.utils import clip_grad_norm, has_overflow

 __AUTHOR__ = 'Facebook'
@@ -114,26 +115,29 @@ def get_parameters(cls, optimizer: Optimizer):
         return params

     @classmethod
-    def build_fp32_params(cls, parameters, flatten: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def build_fp32_params(
+        cls, parameters: PARAMETERS, flatten: bool = True
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         # create FP32 copy of parameters and grads
         if flatten:
-            total_param_size = sum(p.data.numel() for p in parameters)
+            total_param_size: int = sum(p.numel() for p in parameters)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)

            offset: int = 0
            for p in parameters:
-                p_num_el = p.data.numel()
-                fp32_params[offset : offset + p_num_el].copy_(p.data.view(-1))
+                p_num_el = p.numel()
+                fp32_params[offset : offset + p_num_el].copy_(p.view(-1))
                offset += p_num_el

-            fp32_params = torch.nn.Parameter(fp32_params)
-            fp32_params.grad = fp32_params.data.new(total_param_size)
+            fp32_params = nn.Parameter(fp32_params)
+            fp32_params.grad = fp32_params.new(total_param_size)
+
            return fp32_params

        fp32_params = []
        for p in parameters:
-            p32 = torch.nn.Parameter(p.data.float())
-            p32.grad = torch.zeros_like(p32.data)
+            p32 = nn.Parameter(p.float())
+            p32.grad = torch.zeros_like(p32)
            fp32_params.append(p32)

        return fp32_params
@@ -181,25 +185,25 @@ def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
                continue

            if p.grad is not None:
-                p32.grad.data.copy_(p.grad.data)
-                p32.grad.data.mul_(multiply_grads)
+                p32.grad.copy_(p.grad)
+                p32.grad.mul_(multiply_grads)
            else:
-                p32.grad = torch.zeros_like(p.data, dtype=torch.float)
+                p32.grad = torch.zeros_like(p, dtype=torch.float)

        self.needs_sync = False

-    def multiply_grads(self, c):
+    def multiply_grads(self, c: float):
        """Multiplies grads by a constant c."""
        if self.needs_sync:
            self.sync_fp16_grads_to_fp32(c)
        else:
            for p32 in self.fp32_params:
-                p32.grad.data.mul_(c)
+                p32.grad.mul_(c)

    def update_main_grads(self):
        self.sync_fp16_grads_to_fp32()

-    def clip_main_grads(self, max_norm):
+    def clip_main_grads(self, max_norm: float):
        """Clips gradient norm and updates dynamic loss scaler."""
        self.sync_fp16_grads_to_fp32()

@@ -208,8 +212,10 @@ def clip_main_grads(self, max_norm):
        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow: bool = has_overflow(grad_norm)
-            prev_scale = self.scaler.loss_scale
+            prev_scale: float = self.scaler.loss_scale
+
            self.scaler.update_scale(overflow)
+
            if overflow:
                self.zero_grad()
                if self.scaler.loss_scale <= self.min_loss_scale:
@@ -235,7 +241,7 @@ def step(self, closure: CLOSURE = None):
        for p, p32 in zip(self.fp16_params, self.fp32_params):
            if not p.requires_grad:
                continue
-            p.data.copy_(p32.data)
+            p.data.copy_(p32)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""

pytorch_optimizer/lamb.py

Lines changed: 10 additions & 0 deletions
@@ -65,6 +65,16 @@ def validate_parameters(self):
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)

+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
+
     def get_gradient_norm(self) -> float:
         norm_sq: float = 0.0
         for group in self.param_groups:
