Skip to content

Commit 88e75b8

Browse files
committed
update: GrokFast optimizer
1 parent fca320d commit 88e75b8

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

pytorch_optimizer/optimizer/grokfast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def gradfilter_ma(
4646
elif filter_type == 'sum':
4747
avg = sum(grads[n])
4848
else:
49-
raise ValueError(f'Unrecognized filter_type {filter_type}')
49+
raise ValueError(f'not supported filter_type {filter_type}')
5050

5151
p.grad.add_(avg, alpha=lamb)
5252

tests/test_optimizers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
Lookahead,
1414
PCGrad,
1515
ProportionScheduler,
16+
gradfilter_ema,
17+
gradfilter_ma,
1618
load_optimizer,
1719
)
1820
from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError
@@ -608,3 +610,33 @@ def test_schedule_free_train_mode():
608610
opt.reset()
609611
opt.eval()
610612
opt.train()
613+
614+
615+
@pytest.mark.parametrize('filter_type', ['mean', 'sum'])
def test_grokfast_ma(filter_type, environment):
    """Smoke-test ``gradfilter_ma`` once per parametrized ``filter_type``.

    The test makes no assertion on the result; it only requires that the call
    completes without raising for ``'mean'`` and ``'sum'``.
    """
    _, model, _ = environment

    # Give every layer parameter a random gradient before filtering.
    grad_shapes = (
        (model.fc1.weight, (2, 2)),
        (model.fc1.bias, (2,)),
        (model.fc2.weight, (1, 2)),
        (model.fc2.bias, (1,)),
    )
    for param, shape in grad_shapes:
        param.grad = torch.randn(*shape)

    # window_size=1 with warmup disabled, so a single call is enough to exercise
    # the filter branch.
    gradfilter_ma(model, None, window_size=1, filter_type=filter_type, warmup=False)
625+
626+
627+
def test_grokfast_ma_invalid(environment):
    """Check that ``gradfilter_ma`` rejects an unrecognized ``filter_type``."""
    _, model, _ = environment

    # Pin the message as well as the exception type: a bare ``ValueError`` check
    # would also pass if some unrelated ValueError were raised inside the call.
    with pytest.raises(ValueError, match='filter_type'):
        gradfilter_ma(model, None, window_size=1, filter_type='asdf', warmup=False)
632+
633+
634+
def test_grokfast_ema(environment):
    """Smoke-test ``gradfilter_ema`` on a model with random gradients.

    The test makes no assertion on the result; it only requires that the call
    completes without raising.
    """
    _, model, _ = environment

    # Give every layer parameter a random gradient before filtering.
    for param, shape in (
        (model.fc1.weight, (2, 2)),
        (model.fc1.bias, (2,)),
        (model.fc2.weight, (1, 2)),
        (model.fc2.bias, (1,)),
    ):
        param.grad = torch.randn(*shape)

    # ``None`` is passed as the second argument, i.e. no previously accumulated state.
    gradfilter_ema(model, None)

0 commit comments

Comments
 (0)