Commit 0179cbf

feature: tweak AdaFactor optimizer
Parent: 4a095ae

1 file changed: pytorch_optimizer/optimizer/adafactor.py (21 additions, 8 deletions)

@@ -10,11 +10,12 @@
 
 
 class AdaFactor(Optimizer, BaseOptimizer):
-    r"""Adaptive Learning Rates with Sublinear Memory Cost.
+    r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param betas: Union[BETAS, None]. coefficients used for computing running averages of gradient and the squared
+        hessian trace. if betas is None, the first momentum is skipped.
     :param decay_rate: float. coefficient used to compute running averages of square gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
@@ -27,6 +28,9 @@ class AdaFactor(Optimizer, BaseOptimizer):
         is being used.
     :param eps1: float. term added to the denominator to improve numerical stability.
     :param eps2: float. term added to the denominator to improve numerical stability.
+    :param momentum_dtype: torch.dtype. type of momentum variable. The ViT paper observed that storing momentum in
+        half-precision (bfloat16) does not affect training dynamics and has no effect on the outcome, while
+        reducing optimizer overhead from 2-fold to 1.5-fold.
     """
 
     def __init__(
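A back-of-the-envelope check of the "2-fold to 1.5-fold" figure quoted in the new docstring, assuming fp32 weights and a full-size first-moment buffer (the factored second-moment statistics are negligible by comparison); the parameter count below is purely illustrative:

    # Rough sketch of the overhead claim; numbers are illustrative, not from the commit.
    n_params = 86_000_000                        # roughly ViT-Base sized (hypothetical)

    weights = n_params * 4                       # fp32 weights, 4 bytes each
    momentum_fp32 = n_params * 4                 # exp_avg kept in float32
    momentum_bf16 = n_params * 2                 # exp_avg kept in bfloat16

    print((weights + momentum_fp32) / weights)   # 2.0  -> "2-fold" total memory
    print((weights + momentum_bf16) / weights)   # 1.5  -> "1.5-fold" total memory
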
@@ -45,6 +49,7 @@ def __init__(
         warmup_init: bool = False,
         eps1: float = 1e-30,
         eps2: float = 1e-3,
+        momentum_dtype: torch.dtype = torch.bfloat16,
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
@@ -56,6 +61,7 @@ def __init__(
         self.clip_threshold = clip_threshold
         self.eps1 = eps1
         self.eps2 = eps2
+        self.momentum_dtype = momentum_dtype
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -87,7 +93,8 @@ def reset(self):
                 grad_shape: Tuple[int, ...] = grad.shape
                 factored: bool = self.get_options(grad_shape)
 
-                state['exp_avg'] = torch.zeros_like(p)
+                if group['betas'][0] is not None:
+                    state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                 if factored:
                     state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
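For context on the exp_avg_sq_row / exp_avg_sq_col buffers touched by this hunk, here is a minimal sketch of the factored second-moment estimate from the AdaFactor paper; the names mirror the state keys above, but this is an illustration rather than the library's exact code path:

    import torch

    grad = torch.randn(128, 64)
    sq = grad.pow(2) + 1e-30                              # eps1-style stabilizer

    exp_avg_sq_row = sq.mean(dim=-1)                      # shape (128,), O(n) memory
    exp_avg_sq_col = sq.mean(dim=-2)                      # shape (64,),  O(m) memory

    # Rank-1 reconstruction of 1/sqrt(V) from the two factors, built on the fly
    # instead of storing the full (128, 64) second-moment matrix:
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean()).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.rsqrt().unsqueeze(-2)
    preconditioned = r_factor * c_factor * grad           # factored update direction

This is also why only the full-size exp_avg buffer benefits from momentum_dtype: the row and column statistics are tiny and stay in the gradient's dtype.
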
@@ -149,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             else:
                 group['step'] = 1
 
-            beta1, _ = group['betas']
+            beta1, beta2 = group['betas']
 
             beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)
 
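The beta2_t line carried as context here is AdaFactor's step-dependent second-moment decay; a quick illustration of the schedule, assuming the paper's decay_rate of -0.8 (the actual value comes from the constructor):

    import math

    decay_rate = -0.8                       # assumed; matches the AdaFactor paper
    for step in (1, 10, 100, 1000):
        beta2_t = 1.0 - math.pow(step, decay_rate)
        print(step, round(beta2_t, 4))      # 0.0, 0.8415, 0.9749, 0.996

Early steps therefore behave like a short-memory average, and beta2_t drifts toward 1 as training progresses.
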
@@ -167,7 +174,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 factored: bool = self.get_options(grad_shape)
 
                 if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
+                    if beta1 is not None:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                     if factored:
                         state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -205,6 +213,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
+                    exp_avg_sq.clamp_(max=beta2)
+
                 torch.rsqrt(exp_avg_sq, out=update)
 
                 if group['ams_bound']:
@@ -216,8 +226,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)
 
-                exp_avg = state['exp_avg']
-                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+                if beta1 is not None:
+                    exp_avg = state['exp_avg']
+                    exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+
+                    update = exp_avg
 
                 self.apply_weight_decay(
                     p=p,
@@ -228,6 +241,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                p.add_(-exp_avg)
+                p.add_(-update)
 
         return loss
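Putting the tweaks together, usage might look like the sketch below; the model and hyperparameters are placeholders, and the import path assumes the pytorch-optimizer package this file belongs to:

    import torch
    from pytorch_optimizer import AdaFactor

    model = torch.nn.Linear(128, 10)

    # Momentum is kept in bfloat16 by default via momentum_dtype; per the updated
    # docstring, betas can also be configured so the first momentum is skipped.
    optimizer = AdaFactor(model.parameters(), lr=1e-3, momentum_dtype=torch.bfloat16)

    loss = model(torch.randn(4, 128)).sum()
    loss.backward()
    optimizer.step()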
