From 7287836fd8091767d20e7c03aa79db00ac0b13fe Mon Sep 17 00:00:00 2001
From: ferres
Date: Fri, 19 Jul 2024 10:09:03 +0000
Subject: [PATCH 1/5] feature: Add Exponential distribution to autoreparameterization dispatch table

---
 .../model/transforms/autoreparam.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/pymc_experimental/model/transforms/autoreparam.py b/pymc_experimental/model/transforms/autoreparam.py
index cc1f78289..bb3996459 100644
--- a/pymc_experimental/model/transforms/autoreparam.py
+++ b/pymc_experimental/model/transforms/autoreparam.py
@@ -246,6 +246,44 @@ def _(
     return vip_rep
 
 
+@_vip_reparam_node.register
+def _(
+    op: pm.Exponential,
+    node: Apply,
+    name: str,
+    dims: List[Variable],
+    transform: Optional[Transform],
+    lam: pt.TensorVariable,
+) -> ModelDeterministic:
+    rng, size, scale = node.inputs
+    scale_centered = scale**lam
+    scale_noncentered = scale ** (1 - lam)
+    vip_rv_ = pm.Exponential.dist(
+        scale=scale_centered,
+        size=size,
+        rng=rng,
+    )
+    vip_rv_value_ = vip_rv_.clone()
+    vip_rv_.name = f"{name}::tau_"
+    if transform is not None:
+        vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
+    else:
+        vip_rv_value_.name = vip_rv_.name
+    vip_rv = model_free_rv(
+        vip_rv_,
+        vip_rv_value_,
+        transform,
+        *dims,
+    )
+
+    vip_rep_ = scale_noncentered * vip_rv
+
+    vip_rep_.name = name
+
+    vip_rep = model_deterministic(vip_rep_, *dims)
+    return vip_rep
+
+
 def vip_reparametrize(
     model: pm.Model,
     var_names: Sequence[str],

From 0fe63dc721d03624424caa443492371f88e97c3c Mon Sep 17 00:00:00 2001
From: ferres
Date: Fri, 19 Jul 2024 10:22:07 +0000
Subject: [PATCH 2/5] test: add tests for autoreparam

---
 tests/model/transforms/test_autoreparam.py | 39 ++++++++++++----------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py
index b2ea245ae..98558ebb1 100644
--- a/tests/model/transforms/test_autoreparam.py
+++ b/tests/model/transforms/test_autoreparam.py
@@ -11,6 +11,7 @@ def model_c():
         m = pm.Normal("m")
         s = pm.LogNormal("s")
         pm.Normal("g", m, s, shape=5)
+        pm.Exponential("e", scale=s, shape=7)
     return mod
 
 
@@ -20,31 +21,34 @@ def model_nc():
         m = pm.Normal("m")
         s = pm.LogNormal("s")
         pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
+        pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
     return mod
 
 
-def test_reparametrize_created(model_c: pm.Model):
-    model_reparam, vip = vip_reparametrize(model_c, ["g"])
-    assert "g" in vip.get_lambda()
-    assert "g::lam_logit__" in model_reparam.named_vars
-    assert "g::tau_" in model_reparam.named_vars
+@pytest.mark.parameterize("var", ["g", "e"])
+def test_reparametrize_created(model_c: pm.Model, var):
+    model_reparam, vip = vip_reparametrize(model_c, [var])
+    assert f"{var}" in vip.get_lambda()
+    assert f"{var}::lam_logit__" in model_reparam.named_vars
+    assert f"{var}::tau_" in model_reparam.named_vars
     vip.set_all_lambda(1)
-    assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
+    assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
 
 
-def test_random_draw(model_c: pm.Model, model_nc):
+@pytest.mark.parameterize("var", ["g", "e"])
+def test_random_draw(model_c: pm.Model, model_nc, var):
     model_c = pm.do(model_c, {"m": 3, "s": 2})
     model_nc = pm.do(model_nc, {"m": 3, "s": 2})
-    model_v, vip = vip_reparametrize(model_c, ["g"])
-    assert "g" in [v.name for v in model_v.deterministics]
-    c = pm.draw(model_c["g"], random_seed=42, draws=1000)
-    nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
+    model_v, vip = vip_reparametrize(model_c, [var])
+    assert var in [v.name for v in model_v.deterministics]
+    c = pm.draw(model_c[var], random_seed=42, draws=1000)
+    nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
     vip.set_all_lambda(1)
-    v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
     vip.set_all_lambda(0)
-    v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
     vip.set_all_lambda(0.5)
-    v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
+    v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
     np.testing.assert_allclose(c.mean(), nc.mean())
     np.testing.assert_allclose(c.mean(), v_0.mean())
     np.testing.assert_allclose(v_05.mean(), v_1.mean())
@@ -56,11 +60,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
     np.testing.assert_allclose(v_1.std(), nc.std())
 
 
-def test_reparam_fit(model_c):
-    model_v, vip = vip_reparametrize(model_c, ["g"])
+@pytest.mark.parameterize("var", ["g", "e"])
+def test_reparam_fit(model_c, var):
+    model_v, vip = vip_reparametrize(model_c, [var])
     with model_v:
         vip.fit(random_seed=42)
-    np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
+    np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
 
 
 def test_multilevel():

From 74d8c02c131b23f60eed4fd8485530c1c9175a7f Mon Sep 17 00:00:00 2001
From: ferres
Date: Fri, 19 Jul 2024 10:31:34 +0000
Subject: [PATCH 3/5] fix: Fix typo

---
 tests/model/transforms/test_autoreparam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py
index 98558ebb1..5ed95a7f6 100644
--- a/tests/model/transforms/test_autoreparam.py
+++ b/tests/model/transforms/test_autoreparam.py
@@ -25,7 +25,7 @@ def model_nc():
     return mod
 
 
-@pytest.mark.parameterize("var", ["g", "e"])
+@pytest.mark.parametrize("var", ["g", "e"])
 def test_reparametrize_created(model_c: pm.Model, var):
     model_reparam, vip = vip_reparametrize(model_c, [var])
     assert f"{var}" in vip.get_lambda()
@@ -35,7 +35,7 @@ def test_reparametrize_created(model_c: pm.Model, var):
     assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
 
 
-@pytest.mark.parameterize("var", ["g", "e"])
+@pytest.mark.parametrize("var", ["g", "e"])
 def test_random_draw(model_c: pm.Model, model_nc, var):
     model_c = pm.do(model_c, {"m": 3, "s": 2})
     model_nc = pm.do(model_nc, {"m": 3, "s": 2})
@@ -60,7 +60,7 @@ def test_random_draw(model_c: pm.Model, model_nc, var):
     np.testing.assert_allclose(v_1.std(), nc.std())
 
 
-@pytest.mark.parameterize("var", ["g", "e"])
+@pytest.mark.parametrize("var", ["g", "e"])
 def test_reparam_fit(model_c, var):
     model_v, vip = vip_reparametrize(model_c, [var])
     with model_v:

From 162e496d44ad7f8b42009164da9ea8bdf012b40f Mon Sep 17 00:00:00 2001
From: ferres
Date: Fri, 19 Jul 2024 11:12:50 +0000
Subject: [PATCH 4/5] reparam all needed variables when doing fit

---
 tests/model/transforms/test_autoreparam.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py
index 5ed95a7f6..d236a60f4 100644
--- a/tests/model/transforms/test_autoreparam.py
+++ b/tests/model/transforms/test_autoreparam.py
@@ -60,12 +60,13 @@ def test_random_draw(model_c: pm.Model, model_nc, var):
     np.testing.assert_allclose(v_1.std(), nc.std())
 
 
-@pytest.mark.parametrize("var", ["g", "e"])
-def test_reparam_fit(model_c, var):
-    model_v, vip = vip_reparametrize(model_c, [var])
+def test_reparam_fit(model_c):
+    vars = ["g", "e"]
+    model_v, vip = vip_reparametrize(model_c, ["g", "e"])
     with model_v:
         vip.fit(random_seed=42)
-    np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
+    for var in vars:
+        np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
 
 
 def test_multilevel():

From 17fe72c8265e241683c376452da52e5e6bbb4b8e Mon Sep 17 00:00:00 2001
From: ferres
Date: Fri, 19 Jul 2024 14:31:35 +0000
Subject: [PATCH 5/5] test: Increase number of iterations for variational test

---
 tests/model/transforms/test_autoreparam.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py
index d236a60f4..1d2173066 100644
--- a/tests/model/transforms/test_autoreparam.py
+++ b/tests/model/transforms/test_autoreparam.py
@@ -64,7 +64,7 @@ def test_reparam_fit(model_c):
     vars = ["g", "e"]
     model_v, vip = vip_reparametrize(model_c, ["g", "e"])
     with model_v:
-        vip.fit(random_seed=42)
+        vip.fit(50000, random_seed=42)
     for var in vars:
         np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
 
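
For reference, a minimal usage sketch of the Exponential reparameterization introduced in PATCH 1/5. It is not part of the patches: the model, the variable name "e", and the import path are illustrative assumptions based on the module and tests touched above; only API calls exercised in those tests are used.

import numpy as np
import pymc as pm

# Import path assumed from the module modified in PATCH 1/5.
from pymc_experimental.model.transforms.autoreparam import vip_reparametrize

# Illustrative model: an Exponential whose scale depends on another free
# variable, mirroring the "e" variable added to the test fixtures above.
with pm.Model() as model:
    s = pm.LogNormal("s")
    pm.Exponential("e", scale=s, shape=7)

# Rewrite "e" as e = s**(1 - lam) * Exponential(scale=s**lam), the form built
# by the new dispatch entry; lam is exposed through the returned vip object.
model_v, vip = vip_reparametrize(model, ["e"])

vip.set_all_lambda(1)  # fully centered: e ~ Exponential(scale=s)
centered = pm.draw(model_v["e"], random_seed=42, draws=1000)

vip.set_all_lambda(0)  # fully non-centered: e = s * Exponential(scale=1)
noncentered = pm.draw(model_v["e"], random_seed=42, draws=1000)

# Both settings describe the same distribution; only the sampling geometry
# changes. A loose tolerance is used here; test_random_draw above performs
# the precise comparison.
np.testing.assert_allclose(centered.mean(), noncentered.mean(), rtol=0.1)

# To learn lam per variable instead of fixing it, the tests call
# vip.fit(50000, random_seed=42) inside the reparametrized model context.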