We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6a337d5 commit 5cb6b62Copy full SHA for 5cb6b62
pyroapi/tests/test_svi.py
@@ -53,8 +53,11 @@ def model(data=None):
53
54
55
@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
56
-@pytest.mark.parametrize("optim_name", ["Adam", "ClippedAdam"])
57
-def test_optimizer(backend, optim_name, jit):
+@pytest.mark.parametrize("optim_name, optim_kwargs", [
+ ("Adam", {"lr": 1e-6}),
58
+ ("ClippedAdam", {"lr": 1e-6, "lrd": 0.999}),
59
+])
60
+def test_optimizer(backend, optim_name, optim_kwargs, jit):
61
62
def model(data):
63
p = pyro.param("p", ops.tensor(0.5))
@@ -67,7 +70,7 @@ def guide(data):
67
70
pyro.get_param_store().clear()
68
71
Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO
69
72
elbo = Elbo(ignore_jit_warnings=True)
- optimizer = getattr(optim, optim_name)({"lr": 1e-6})
73
+ optimizer = getattr(optim, optim_name)(optim_kwargs.copy())
74
inference = infer.SVI(model, guide, optimizer, elbo)
75
for i in range(2):
76
inference.step(data)
0 commit comments