
Commit 15d52f6

Merge pull request #243 from kozistr/update/adafactor-optimizer
[Feature] Tweak AdaFactor optimizer
2 parents: 4a095ae + 5d924c5

5 files changed (+41 lines, -21 lines)


docs/changelogs/v3.0.1.md

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 
 * Implement `FAdam` optimizer. (#241, #242)
     * [Adam is a natural gradient optimizer using diagonal empirical Fisher information](https://arxiv.org/abs/2405.12807)
+* Tweak `AdaFactor` optimizer. (#236, #243)
+    * support not-using-first-momentum when beta1 is not given
+    * default dtype for first momentum to `bfloat16`
+    * clip second momentum to 0.999
 
 ### Bug
 
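For context, here is a small usage sketch of the tweaked options described above. The `AdaFactor` parameters come from the diffs below; the import path, model, and learning rate are illustrative assumptions, not taken from this commit.

```python
import torch
from pytorch_optimizer import AdaFactor  # assumed top-level export of the package

model = torch.nn.Linear(16, 1)

# betas=(None, 0.999) skips the first momentum entirely (no `exp_avg` state is kept);
# momentum_dtype only matters when beta1 is given and defaults to torch.bfloat16.
optimizer = AdaFactor(model.parameters(), lr=1e-2, betas=(None, 0.999), weight_decay=1e-3)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```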

pytorch_optimizer/base/optimizer.py

Lines changed: 12 additions & 10 deletions
@@ -215,7 +215,7 @@ def get_adanorm_gradient(
         return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
 
     @staticmethod
-    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)'):
+    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
         if range_type == '[)' and not low <= x < high:
             raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
         if range_type == '[]' and not low <= x <= high:
@@ -226,40 +226,42 @@ def validate_range(x: float, name: str, low: float, high: float, range_type: str
             raise ValueError(f'[-] {name} must be in the range ({low}, {high})')
 
     @staticmethod
-    def validate_non_negative(x: Optional[float], name: str):
+    def validate_non_negative(x: Optional[float], name: str) -> None:
         if x is not None and x < 0.0:
             raise ValueError(f'[-] {name} must be non-negative')
 
     @staticmethod
-    def validate_positive(x: Union[float, int], name: str):
+    def validate_positive(x: Union[float, int], name: str) -> None:
         if x <= 0:
             raise ValueError(f'[-] {name} must be positive')
 
     @staticmethod
-    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper'):
+    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
         if bound_type == 'upper' and constant > boundary:
             raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
         if bound_type == 'lower' and constant < boundary:
             raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')
 
     @staticmethod
-    def validate_step(step: int, step_type: str):
+    def validate_step(step: int, step_type: str) -> None:
         if step < 1:
             raise NegativeStepError(step, step_type=step_type)
 
     @staticmethod
-    def validate_options(x: str, name: str, options: List[str]):
+    def validate_options(x: str, name: str, options: List[str]) -> None:
         if x not in options:
             opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip()
             raise ValueError(f'[-] {name} {x} must be one of ({opts})')
 
     @staticmethod
-    def validate_learning_rate(learning_rate: Optional[float]):
+    def validate_learning_rate(learning_rate: Optional[float]) -> None:
         if learning_rate is not None and learning_rate < 0.0:
             raise NegativeLRError(learning_rate)
 
-    def validate_betas(self, betas: BETAS):
-        self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
+    def validate_betas(self, betas: BETAS) -> None:
+        if betas[0] is not None:
+            self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
+
         self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')
 
         if len(betas) < 3:
@@ -268,7 +270,7 @@ def validate_betas(self, betas: BETAS):
         if betas[2] is not None:
             self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')
 
-    def validate_nus(self, nus: Union[float, Tuple[float, float]]):
+    def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
         if isinstance(nus, float):
             self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
         else:
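To show what the relaxed validation accepts, here is a minimal standalone sketch of the new `validate_betas` behaviour (a simplified re-implementation for illustration; the real methods live on `BaseOptimizer` and also handle an optional beta3):

```python
from typing import Optional, Tuple, Union

BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]

def validate_range(x: float, name: str, low: float, high: float) -> None:
    if not low <= x <= high:
        raise ValueError(f'[-] {name} must be in the range [{low}, {high}]')

def validate_betas(betas: BETAS) -> None:
    if betas[0] is not None:  # new: beta1 may be None to disable the first momentum
        validate_range(betas[0], 'beta1', 0.0, 1.0)
    validate_range(betas[1], 'beta2', 0.0, 1.0)

validate_betas((0.9, 0.999))   # passes: the usual two-beta configuration
validate_betas((None, 0.999))  # passes: first momentum disabled
# validate_betas((1.5, 0.999)) would raise ValueError: beta1 out of range
```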

pytorch_optimizer/base/types.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 CLOSURE = Optional[Callable[[], float]]
 LOSS = Optional[float]
-BETAS = Union[Tuple[float, float], Tuple[float, float, float]]
+BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]
 DEFAULTS = Dict
 PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
 STATE = Dict

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 21 additions & 8 deletions
@@ -10,11 +10,12 @@
 
 
 class AdaFactor(Optimizer, BaseOptimizer):
-    r"""Adaptive Learning Rates with Sublinear Memory Cost.
+    r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared
+        hessian trace. if beta1 is None, first momentum will be skipped.
     :param decay_rate: float. coefficient used to compute running averages of square gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
@@ -27,6 +28,9 @@ class AdaFactor(Optimizer, BaseOptimizer):
         is being used.
     :param eps1: float. term added to the denominator to improve numerical stability.
     :param eps2: float. term added to the denominator to improve numerical stability.
+    :param momentum_dtype: torch.dtype. type of momentum variable. In VIT paper observed that storing momentum in
+        half-precision (bfloat16 type) does not affect training dynamics and has no effect on the outcome while
+        reducing optimize overhead from 2-fold to 1.5-fold.
     """
 
     def __init__(
@@ -45,6 +49,7 @@ def __init__(
         warmup_init: bool = False,
         eps1: float = 1e-30,
         eps2: float = 1e-3,
+        momentum_dtype: torch.dtype = torch.bfloat16,
     ):
         self.validate_learning_rate(lr)
         self.validate_betas(betas)
@@ -56,6 +61,7 @@
         self.clip_threshold = clip_threshold
         self.eps1 = eps1
         self.eps2 = eps2
+        self.momentum_dtype = momentum_dtype
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -87,7 +93,8 @@ def reset(self):
                 grad_shape: Tuple[int, ...] = grad.shape
                 factored: bool = self.get_options(grad_shape)
 
-                state['exp_avg'] = torch.zeros_like(p)
+                if group['betas'][0] is not None:
+                    state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                 if factored:
                     state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -149,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             else:
                 group['step'] = 1
 
-            beta1, _ = group['betas']
+            beta1, beta2 = group['betas']
 
             beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)
 
@@ -167,7 +174,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 factored: bool = self.get_options(grad_shape)
 
                 if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
+                    if beta1 is not None:
+                        state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)
 
                     if factored:
                         state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
@@ -205,6 +213,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
+                    exp_avg_sq.clamp_(max=beta2)
+
                 torch.rsqrt(exp_avg_sq, out=update)
 
                 if group['ams_bound']:
@@ -216,8 +226,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)
 
-                exp_avg = state['exp_avg']
-                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+                if beta1 is not None:
+                    exp_avg = state['exp_avg']
+                    exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
+
+                    update = exp_avg
 
                 self.apply_weight_decay(
                     p=p,
@@ -228,6 +241,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                p.add_(-exp_avg)
+                p.add_(-update)
 
         return loss
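To make the new control flow concrete, here is a simplified, hypothetical rendering of the momentum handling above (an illustration only, not the optimizer's actual `step()`; the explicit dtype cast is added here so the in-place EMA stays within `momentum_dtype`):

```python
from typing import Dict, Optional

import torch

def momentum_step(
    p: torch.Tensor,
    update: torch.Tensor,
    state: Dict[str, torch.Tensor],
    beta1: Optional[float],
    momentum_dtype: torch.dtype = torch.bfloat16,
) -> None:
    if beta1 is not None:
        # the first momentum is allocated lazily and stored, by default, in bfloat16,
        # roughly halving its memory footprint relative to float32
        if 'exp_avg' not in state:
            state['exp_avg'] = torch.zeros_like(p, dtype=momentum_dtype)
        exp_avg = state['exp_avg']
        exp_avg.mul_(beta1).add_(update.to(momentum_dtype), alpha=1.0 - beta1)
        update = exp_avg
    # with beta1=None the raw (already rms-clipped, lr-scaled) update is applied directly
    p.add_(-update)

param = torch.randn(4, 4)
momentum_step(param, torch.full_like(param, 1e-2), {}, beta1=None)  # no exp_avg state kept
momentum_step(param, torch.full_like(param, 1e-2), {}, beta1=0.9)   # EMA of updates in bfloat16
```

With float32 parameters, a bfloat16 `exp_avg` costs 2 bytes per element instead of 4, which is where the docstring's "2-fold to 1.5-fold" optimizer-overhead figure comes from when only the first momentum is counted.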

tests/constants.py

Lines changed: 3 additions & 2 deletions
@@ -348,8 +348,9 @@
     (DAdaptLion, {'lr': 3e0, 'weight_decay': 1e-3}, 20),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
-    (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
-    (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'ams_bound': True}, 125),
+    (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
+    (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
+    (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
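A hedged sketch of how entries like these are typically consumed is below; the test function and case list are hypothetical, not the repository's actual fixtures, and the third tuple element is read here as an iteration budget (an assumption, not stated in the diff):

```python
import pytest
import torch

from pytorch_optimizer import AdaFactor  # assumed top-level export

# hypothetical subset mirroring the entries above
OPTIMIZER_CASES = [
    (AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
    (AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
]

@pytest.mark.parametrize('optimizer_class, config, num_iterations', OPTIMIZER_CASES)
def test_optimizer_runs(optimizer_class, config, num_iterations):
    weight = torch.nn.Parameter(torch.randn(2, 2))
    optimizer = optimizer_class([weight], **config)

    for _ in range(num_iterations):
        optimizer.zero_grad()
        loss = weight.pow(2).sum()  # a toy quadratic objective
        loss.backward()
        optimizer.step()
```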
