Skip to content

Commit d8853d1

Browse files
committed
update: drop GaLore variant support
1 parent f3c6476 commit d8853d1

File tree

1 file changed

+0
-18
lines changed

1 file changed

+0
-18
lines changed

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pytorch_optimizer.base.exception import NoSparseGradientError
88
from pytorch_optimizer.base.optimizer import BaseOptimizer
99
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
10-
from pytorch_optimizer.optimizer.galore import GaLoreProjector
1110

1211

1312
class AdaFactor(Optimizer, BaseOptimizer):
@@ -28,7 +27,6 @@ class AdaFactor(Optimizer, BaseOptimizer):
2827
is being used.
2928
:param eps1: float. term added to the denominator to improve numerical stability.
3029
:param eps2: float. term added to the denominator to improve numerical stability.
31-
:param use_galore: bool. use GaLore variant.
3230
"""
3331

3432
def __init__(
@@ -47,7 +45,6 @@ def __init__(
4745
warmup_init: bool = False,
4846
eps1: float = 1e-30,
4947
eps2: float = 1e-3,
50-
**kwargs,
5148
):
5249
self.validate_learning_rate(lr)
5350
self.validate_betas(betas)
@@ -72,7 +69,6 @@ def __init__(
7269
'warmup_init': warmup_init,
7370
'eps1': eps1,
7471
'eps2': eps2,
75-
**kwargs,
7672
}
7773
super().__init__(params, defaults)
7874

@@ -170,17 +166,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
170166
grad_shape: Tuple[int, ...] = grad.shape
171167
factored: bool = self.get_options(grad_shape)
172168

173-
if 'rank' in group:
174-
if 'projector' not in state:
175-
state['projector'] = GaLoreProjector(
176-
rank=group['rank'],
177-
update_proj_gap=group['update_proj_gap'],
178-
scale=group['scale'],
179-
projection_type=group['projection_type'],
180-
)
181-
182-
grad = state['projector'].project(grad, group['step'])
183-
184169
if len(state) == 0:
185170
state['exp_avg'] = torch.zeros_like(p)
186171

@@ -234,9 +219,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
234219
exp_avg = state['exp_avg']
235220
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
236221

237-
if 'rank' in group:
238-
exp_avg = state['projector'].project_back(exp_avg)
239-
240222
self.apply_weight_decay(
241223
p=p,
242224
grad=None,

0 commit comments

Comments (0)