11from typing import List , Optional , Tuple , Union
22
3- import numpy as np
43import torch
54from torch import Tensor
65from torch .optim import Optimizer
@@ -10,38 +9,36 @@ def _get_scalar_dtype():
109 """Get the scalar dtype that the optimizer uses for state"""
1110 return torch .float64
1211
12+
1313def _factored_dims (
14- shape : Tuple [int , ...],
15- factored : bool ,
16- min_dim_size_to_factor : int
14+ shape : Tuple [int , ...],
15+ factored : bool ,
16+ min_dim_size_to_factor : int
1717) -> Optional [tuple [int , int ]]:
18- """Whether to use a factored second moment estimator.
18+ """Whether to use a factored second moment estimator.
1919
20- This function returns a tuple with the two largest axes to reduce over.
21- If no two dimensions have size >= min_dim_size_to_factor, return None.
20+ This function returns a tuple with the two largest axes to reduce over.
21+ If no two dimensions have size >= min_dim_size_to_factor, return None.
2222
23- Args:
24- shape: an input shape
25- factored: whether to use factored second-moment estimator for > 2d vars.
26- min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
23+ Args:
24+ shape: an input shape
25+ factored: whether to use factored second-moment estimator for > 2d vars.
26+ min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
2727
28- Returns:
29- None or a tuple of ints
30- """
31- if not factored or len (shape ) < 2 :
32- return None
33- sorted_dims = np . argsort ( shape )
34- if shape [sorted_dims [- 2 ]] < min_dim_size_to_factor :
35- return None
36- return int (sorted_dims [- 2 ]), int (sorted_dims [- 1 ])
28+ Returns:
29+ None or a tuple of ints
30+ """
31+ if not factored or len (shape ) < 2 :
32+ return None
33+ sorted_dims = sorted ((( x , i ) for i , x in enumerate ( shape )) )
34+ if shape [sorted_dims [- 2 ][ 1 ]] < min_dim_size_to_factor :
35+ return None
36+ return int (sorted_dims [- 2 ][ 1 ] ), int (sorted_dims [- 1 ][ 1 ])
3737
3838
3939class AdafactorBigVision (Optimizer ):
4040 """
4141 PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
42-
43-
44-
4542 """
4643
4744 def __init__ (
@@ -95,6 +92,12 @@ def __setstate__(self, state):
9592 if len (p_state ) != 0 and not torch .is_tensor (p_state ['step' ]):
9693 p_state ['step' ] = torch .tensor (float (p_state ['step' ]), dtype = _get_scalar_dtype ())
9794
95+ if 'exp_avg' in p_state and torch .is_tensor (p_state ['exp_avg' ]):
96+ # FIXME this is a bit of a hack, optimizer.load_state_dict appears to upcast
97+ # the momentum to float32 (it's half precision in the state_dict), need to
98+ # look into this further. Better to override _process_value_according_to_param_policy?
99+ p_state ['exp_avg' ] = p_state ['exp_avg' ].to (dtype = self .defaults ['momentum_dtype' ])
100+
98101 @torch .no_grad ()
99102 def step (self , closure = None ):
100103 loss = None
@@ -181,6 +184,7 @@ def step(self, closure=None):
181184
182185 return loss
183186
187+
184188def _single_tensor_adafactor (
185189 params : List [Tensor ],
186190 grads : List [Tensor ],
@@ -262,24 +266,25 @@ def _single_tensor_adafactor(
262266 # Update parameters
263267 param .add_ (update , alpha = - 1.0 )
264268
269+
265270def _multi_tensor_adafactor (
266- params : List [Tensor ],
267- grads : List [Tensor ],
268- exp_avg_sq_rs : List [Optional [Tensor ]],
269- exp_avg_sq_cs : List [Optional [Tensor ]],
270- exp_avg_sqs : List [Optional [Tensor ]],
271- exp_avgs : List [Optional [Tensor ]],
272- state_steps : List [Tensor ],
273- * ,
274- beta2_decay : float ,
275- beta2_cap : float ,
276- min_dim_size_to_factor : int ,
277- eps : float ,
278- lr : float ,
279- weight_decay : float ,
280- momentum : Optional [float ],
281- momentum_dtype : Union [str , torch .dtype ],
282- clipping_threshold : Optional [float ],
283- unscaled_wd : bool ,
271+ params : List [Tensor ],
272+ grads : List [Tensor ],
273+ exp_avg_sq_rs : List [Optional [Tensor ]],
274+ exp_avg_sq_cs : List [Optional [Tensor ]],
275+ exp_avg_sqs : List [Optional [Tensor ]],
276+ exp_avgs : List [Optional [Tensor ]],
277+ state_steps : List [Tensor ],
278+ * ,
279+ beta2_decay : float ,
280+ beta2_cap : float ,
281+ min_dim_size_to_factor : int ,
282+ eps : float ,
283+ lr : float ,
284+ weight_decay : float ,
285+ momentum : Optional [float ],
286+ momentum_dtype : Union [str , torch .dtype ],
287+ clipping_threshold : Optional [float ],
288+ unscaled_wd : bool ,
284289):
285290 assert False , 'multi-tensor fn (foreach=True) not implemented yet'
0 commit comments