Skip to content

Commit 5cb6b62

Browse files
authored
Modify test to check ClippedAdam's learning rate decay argument (#9)
1 parent 6a337d5 commit 5cb6b62

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pyroapi/tests/test_svi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ def model(data=None):
5353

5454

5555
@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
56-
@pytest.mark.parametrize("optim_name", ["Adam", "ClippedAdam"])
57-
def test_optimizer(backend, optim_name, jit):
56+
@pytest.mark.parametrize("optim_name, optim_kwargs", [
57+
("Adam", {"lr": 1e-6}),
58+
("ClippedAdam", {"lr": 1e-6, "lrd": 0.999}),
59+
])
60+
def test_optimizer(backend, optim_name, optim_kwargs, jit):
5861

5962
def model(data):
6063
p = pyro.param("p", ops.tensor(0.5))
@@ -67,7 +70,7 @@ def guide(data):
6770
pyro.get_param_store().clear()
6871
Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
6972
elbo = Elbo(ignore_jit_warnings=True)
70-
optimizer = getattr(optim, optim_name)({"lr": 1e-6})
73+
optimizer = getattr(optim, optim_name)(optim_kwargs.copy())
7174
inference = infer.SVI(model, guide, optimizer, elbo)
7275
for i in range(2):
7376
inference.step(data)

0 commit comments

Comments
 (0)