|
13 | 13 | Lookahead, |
14 | 14 | PCGrad, |
15 | 15 | ProportionScheduler, |
| 16 | + gradfilter_ema, |
| 17 | + gradfilter_ma, |
16 | 18 | load_optimizer, |
17 | 19 | ) |
18 | 20 | from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError |
@@ -608,3 +610,33 @@ def test_schedule_free_train_mode(): |
608 | 610 | opt.reset() |
609 | 611 | opt.eval() |
610 | 612 | opt.train() |
| 613 | + |
| 614 | + |
| 615 | +@pytest.mark.parametrize('filter_type', ['mean', 'sum']) |
| 616 | +def test_grokfast_ma(filter_type, environment): |
| 617 | + _, model, _ = environment |
| 618 | + |
| 619 | + model.fc1.weight.grad = torch.randn(2, 2) |
| 620 | + model.fc1.bias.grad = torch.randn(2) |
| 621 | + model.fc2.weight.grad = torch.randn(1, 2) |
| 622 | + model.fc2.bias.grad = torch.randn(1) |
| 623 | + |
| 624 | + _ = gradfilter_ma(model, None, window_size=1, filter_type=filter_type, warmup=False) |
| 625 | + |
| 626 | + |
| 627 | +def test_grokfast_ma_invalid(environment): |
| 628 | + _, model, _ = environment |
| 629 | + |
| 630 | + with pytest.raises(ValueError): |
| 631 | + _ = gradfilter_ma(model, None, window_size=1, filter_type='asdf', warmup=False) |
| 632 | + |
| 633 | + |
| 634 | +def test_grokfast_ema(environment): |
| 635 | + _, model, _ = environment |
| 636 | + |
| 637 | + model.fc1.weight.grad = torch.randn(2, 2) |
| 638 | + model.fc1.bias.grad = torch.randn(2) |
| 639 | + model.fc2.weight.grad = torch.randn(1, 2) |
| 640 | + model.fc2.bias.grad = torch.randn(1) |
| 641 | + |
| 642 | + _ = gradfilter_ema(model, None) |
0 commit comments