77from pytorch_optimizer .base .exception import NoSparseGradientError
88from pytorch_optimizer .base .optimizer import BaseOptimizer
99from pytorch_optimizer .base .types import BETAS , CLOSURE , DEFAULTS , LOSS , PARAMETERS
10- from pytorch_optimizer .optimizer .galore import GaLoreProjector
1110
1211
1312class AdaFactor (Optimizer , BaseOptimizer ):
@@ -28,7 +27,6 @@ class AdaFactor(Optimizer, BaseOptimizer):
2827 is being used.
2928 :param eps1: float. term added to the denominator to improve numerical stability.
3029 :param eps2: float. term added to the denominator to improve numerical stability.
31- :param use_galore: bool. use GaLore variant.
3230 """
3331
3432 def __init__ (
@@ -47,7 +45,6 @@ def __init__(
4745 warmup_init : bool = False ,
4846 eps1 : float = 1e-30 ,
4947 eps2 : float = 1e-3 ,
50- ** kwargs ,
5148 ):
5249 self .validate_learning_rate (lr )
5350 self .validate_betas (betas )
@@ -72,7 +69,6 @@ def __init__(
7269 'warmup_init' : warmup_init ,
7370 'eps1' : eps1 ,
7471 'eps2' : eps2 ,
75- ** kwargs ,
7672 }
7773 super ().__init__ (params , defaults )
7874
@@ -170,17 +166,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
170166 grad_shape : Tuple [int , ...] = grad .shape
171167 factored : bool = self .get_options (grad_shape )
172168
173- if 'rank' in group :
174- if 'projector' not in state :
175- state ['projector' ] = GaLoreProjector (
176- rank = group ['rank' ],
177- update_proj_gap = group ['update_proj_gap' ],
178- scale = group ['scale' ],
179- projection_type = group ['projection_type' ],
180- )
181-
182- grad = state ['projector' ].project (grad , group ['step' ])
183-
184169 if len (state ) == 0 :
185170 state ['exp_avg' ] = torch .zeros_like (p )
186171
@@ -234,9 +219,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
234219 exp_avg = state ['exp_avg' ]
235220 exp_avg .mul_ (beta1 ).add_ (update , alpha = 1.0 - beta1 )
236221
237- if 'rank' in group :
238- exp_avg = state ['projector' ].project_back (exp_avg )
239-
240222 self .apply_weight_decay (
241223 p = p ,
242224 grad = None ,
0 commit comments