Skip to content

Commit fca320d

Browse files
committed
update: GrokFast optimizer
1 parent 4609df4 commit fca320d

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from pytorch_optimizer.optimizer.galore import GaLore, GaLoreProjector
6161
from pytorch_optimizer.optimizer.gc import centralize_gradient
6262
from pytorch_optimizer.optimizer.gravity import Gravity
63+
from pytorch_optimizer.optimizer.grokfast import GrokFastAdamW, gradfilter_ema, gradfilter_ma
6364
from pytorch_optimizer.optimizer.lamb import Lamb
6465
from pytorch_optimizer.optimizer.lars import LARS
6566
from pytorch_optimizer.optimizer.lion import Lion
@@ -192,6 +193,7 @@
192193
ScheduleFreeSGD,
193194
ScheduleFreeAdamW,
194195
FAdam,
196+
GrokFastAdamW,
195197
]
196198
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
197199

tests/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
Fromage,
5050
GaLore,
5151
Gravity,
52+
GrokFastAdamW,
5253
Lamb,
5354
Lion,
5455
Nero,
@@ -129,6 +130,7 @@
129130
'bsam',
130131
'schedulefreeadamw',
131132
'fadam',
133+
'grokfastadamw',
132134
]
133135

134136
VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -448,6 +450,7 @@
448450
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
449451
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
450452
(FAdam, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
453+
(GrokFastAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
451454
]
452455
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
453456
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_load_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_load_lr_scheduler_invalid(invalid_lr_scheduler_names):
3838

3939

4040
def test_get_supported_optimizers():
41-
assert len(get_supported_optimizers()) == 67
41+
assert len(get_supported_optimizers()) == 68
4242

4343

4444
def test_get_supported_lr_schedulers():

0 commit comments

Comments
 (0)