|
28 | 28 | PULLBACK_MOMENTUM, |
29 | 29 | ) |
30 | 30 | from tests.utils import ( |
| 31 | + Example, |
31 | 32 | MultiHeadLogisticRegression, |
32 | 33 | build_environment, |
33 | 34 | dummy_closure, |
@@ -578,22 +579,28 @@ def test_lomo_fused_backward(optimizer_name, environment): |
578 | 579 |
|
@pytest.mark.parametrize('optimizer_name', ['lomo', 'adalomo'])
@pytest.mark.parametrize('precision', [16, 32])
def test_lomo_optimizer(optimizer_name, precision):
    """Smoke-test LOMO/AdaLOMO fused backward on fp16 and fp32 parameters."""
    model = Example()

    # Give the bias a concrete fp32 value and gradient before stepping.
    model.fc1.bias.data = torch.randn(1, dtype=torch.float32)
    model.fc1.bias.grad = torch.randn(1, dtype=torch.float32)

    # Half-precision weights exercise the fp16 code path of the optimizer.
    if precision == 16:
        model.fc1.weight.data = torch.randn(1, 1, dtype=torch.float16)
        model.fc1.weight.grad = torch.zeros(1, 1, dtype=torch.float16)

    optimizer = load_optimizer(optimizer_name)(model, clip_grad_norm=1.0, clip_grad_value=1.0)

    # NOTE(review): presumably forces the gradient-clipping branch under fp16 — confirm.
    if precision == 16:
        optimizer.clip_coef = 0.9

    # Run one grad_norm + fused_backward pass against each of the first
    # two parameters, consuming the same iterator as the original did.
    param_iter = iter(model.parameters())
    for _ in range(2):
        loss = sphere_loss(next(param_iter))
        optimizer.grad_norm(loss)
        optimizer.fused_backward(loss, lr=0.1)
599 | 606 |
|
|
0 commit comments