
Commit 548fdb5

Remove adafactorbv numpy dep, hack fix for loading optimizer state w/ half prec momentum (need better one)
1 parent 91f0ea3 commit 548fdb5

File tree

1 file changed: +46 -41 lines changed

timm/optim/adafactor_bv.py

Lines changed: 46 additions & 41 deletions
@@ -1,6 +1,5 @@
 from typing import List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
     return torch.float64
 
+
 def _factored_dims(
-    shape: Tuple[int, ...],
-    factored: bool,
-    min_dim_size_to_factor: int
+        shape: Tuple[int, ...],
+        factored: bool,
+        min_dim_size_to_factor: int
 ) -> Optional[tuple[int, int]]:
-  """Whether to use a factored second moment estimator.
+    """Whether to use a factored second moment estimator.
 
-  This function returns a tuple with the two largest axes to reduce over.
-  If no two dimensions have size >= min_dim_size_to_factor, return None.
+    This function returns a tuple with the two largest axes to reduce over.
+    If no two dimensions have size >= min_dim_size_to_factor, return None.
 
-  Args:
-    shape: an input shape
-    factored: whether to use factored second-moment estimator for > 2d vars.
-    min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
+    Args:
+        shape: an input shape
+        factored: whether to use factored second-moment estimator for > 2d vars.
+        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
 
-  Returns:
-    None or a tuple of ints
-  """
-  if not factored or len(shape) < 2:
-    return None
-  sorted_dims = np.argsort(shape)
-  if shape[sorted_dims[-2]] < min_dim_size_to_factor:
-    return None
-  return int(sorted_dims[-2]), int(sorted_dims[-1])
+    Returns:
+        None or a tuple of ints
+    """
+    if not factored or len(shape) < 2:
+        return None
+    sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
+    if shape[sorted_dims[-2][1]] < min_dim_size_to_factor:
+        return None
+    return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
 
 
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
-
-
-
     """
 
     def __init__(
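
The numpy removal in this hunk comes down to swapping np.argsort for a plain Python sort over (size, index) pairs. A minimal sketch of the equivalence on one made-up shape (illustration only, not part of the commit):

import numpy as np

shape = (8, 4, 64, 128)

# Old: np.argsort orders dim indices by size; the last two point at the largest dims.
old = np.argsort(shape)
# New: sort (size, index) pairs in pure Python and read the index out of each pair.
sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
new = (sorted_dims[-2][1], sorted_dims[-1][1])

assert (int(old[-2]), int(old[-1])) == new == (2, 3)  # the axes of size 64 and 128
# For equal-sized dims the exact indices may differ between the two, but the selected sizes are the same.
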
@@ -95,6 +92,12 @@ def __setstate__(self, state):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
+                if 'exp_avg' in p_state and torch.is_tensor(p_state['exp_avg']):
+                    # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
+                    # the momentum to float32 (it's half precision in the state_dict), need to
+                    # look into this further. Better to override _process_value_according_to_param_policy?
+                    p_state['exp_avg'] = p_state['exp_avg'].to(dtype=self.defaults['momentum_dtype'])
+
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
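
The __setstate__ addition works around torch.optim.Optimizer.load_state_dict, which appears to re-cast floating-point state to the parameter dtype (float32) and so silently undoes the half-precision momentum this optimizer keeps; load_state_dict finishes by calling __setstate__, so the hook can cast 'exp_avg' back down. A rough sketch of the round trip being patched follows; the keyword names mirror the signatures in this file, but the constructor call, defaults, and dtype are assumptions for illustration, not taken from the diff:

import torch
from timm.optim.adafactor_bv import AdafactorBigVision

model = torch.nn.Linear(16, 16)
opt = AdafactorBigVision(model.parameters(), lr=1e-3, momentum=0.9, momentum_dtype=torch.bfloat16)

model(torch.randn(2, 16)).sum().backward()
opt.step()  # creates the 'exp_avg' momentum buffers in bfloat16

sd = opt.state_dict()    # momentum buffers serialized in bfloat16
opt.load_state_dict(sd)  # the base class upcasts them to float32 to match the float32 params...
# ...and the new __setstate__ hook casts 'exp_avg' back to defaults['momentum_dtype'],
# so this should hold once the fix above is in place:
assert opt.state[model.weight]['exp_avg'].dtype == torch.bfloat16
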
@@ -181,6 +184,7 @@ def step(self, closure=None):
 
         return loss
 
+
 def _single_tensor_adafactor(
         params: List[Tensor],
         grads: List[Tensor],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
         # Update parameters
         param.add_(update, alpha=-1.0)
 
+
 def _multi_tensor_adafactor(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avg_sq_rs: List[Optional[Tensor]],
-    exp_avg_sq_cs: List[Optional[Tensor]],
-    exp_avg_sqs: List[Optional[Tensor]],
-    exp_avgs: List[Optional[Tensor]],
-    state_steps: List[Tensor],
-    *,
-    beta2_decay: float,
-    beta2_cap: float,
-    min_dim_size_to_factor: int,
-    eps: float,
-    lr: float,
-    weight_decay: float,
-    momentum: Optional[float],
-    momentum_dtype: Union[str, torch.dtype],
-    clipping_threshold: Optional[float],
-    unscaled_wd: bool,
+        params: List[Tensor],
+        grads: List[Tensor],
+        exp_avg_sq_rs: List[Optional[Tensor]],
+        exp_avg_sq_cs: List[Optional[Tensor]],
+        exp_avg_sqs: List[Optional[Tensor]],
+        exp_avgs: List[Optional[Tensor]],
+        state_steps: List[Tensor],
+        *,
+        beta2_decay: float,
+        beta2_cap: float,
+        min_dim_size_to_factor: int,
+        eps: float,
+        lr: float,
+        weight_decay: float,
+        momentum: Optional[float],
+        momentum_dtype: Union[str, torch.dtype],
+        clipping_threshold: Optional[float],
+        unscaled_wd: bool,
 ):
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
