Skip to content

Commit c3b4309

Browse files
committed
update: test_lomo_optimizer
1 parent 2e5fb36 commit c3b4309

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tests/test_optimizers.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
PULLBACK_MOMENTUM,
2929
)
3030
from tests.utils import (
31+
Example,
3132
MultiHeadLogisticRegression,
3233
build_environment,
3334
dummy_closure,
@@ -578,22 +579,28 @@ def test_lomo_fused_backward(optimizer_name, environment):
578579

579580
@pytest.mark.parametrize('optimizer_name', ['lomo', 'adalomo'])
580581
@pytest.mark.parametrize('precision', [16, 32])
581-
def test_lomo_optimizer(optimizer_name, precision, environment):
582-
_, model, _ = environment
582+
def test_lomo_optimizer(optimizer_name, precision):
583+
model = Example()
583584

584-
model.fc1.bias.data = torch.randn(2, dtype=torch.float32)
585-
model.fc1.bias.grad = torch.zeros(2, dtype=torch.float32)
585+
model.fc1.bias.data = torch.randn(1, dtype=torch.float32)
586+
model.fc1.bias.grad = torch.randn(1, dtype=torch.float32)
586587

587588
if precision == 16:
588-
model.fc1.weight.data = torch.randn(2, 2, dtype=torch.float16)
589-
model.fc1.weight.grad = torch.zeros(2, 2, dtype=torch.float16)
589+
model.fc1.weight.data = torch.randn(1, 1, dtype=torch.float16)
590+
model.fc1.weight.grad = torch.zeros(1, 1, dtype=torch.float16)
590591

591592
optimizer = load_optimizer(optimizer_name)(model, clip_grad_norm=1.0, clip_grad_value=1.0)
592593

593594
if precision == 16:
594595
optimizer.clip_coef = 0.9
595596

596-
loss = sphere_loss(next(iter(model.parameters())))
597+
parameters = iter(model.parameters())
598+
599+
loss = sphere_loss(next(parameters))
600+
optimizer.grad_norm(loss)
601+
optimizer.fused_backward(loss, lr=0.1)
602+
603+
loss = sphere_loss(next(parameters))
597604
optimizer.grad_norm(loss)
598605
optimizer.fused_backward(loss, lr=0.1)
599606

0 commit comments

Comments (0)