Skip to content

Commit bcdea46

Browse files
committed
add failing test for all cases of how shape can go
1 parent d5355a5 commit bcdea46

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/model/transforms/test_autoreparam.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,29 @@ def test_set_truncate(model_c: pm.Model):
103103
vip.truncate_lambda(g=0.2)
104104
np.testing.assert_allclose(vip.get_lambda()["g"], 1)
105105
np.testing.assert_allclose(vip.get_lambda()["m"], 0.9)
106+
107+
108+
@pytest.mark.xfail(reason="FIX shape computation for lambda")
109+
def test_lambda_shape():
110+
with pm.Model(coords=dict(a=[1, 2])) as model:
111+
b1 = pm.Normal("b1", dims="a")
112+
b2 = pm.Normal("b2", shape=2)
113+
b3 = pm.Normal("b3", size=2)
114+
b4 = pm.Normal("b4", np.asarray([1, 2]))
115+
model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
116+
lams = vip.get_lambda()
117+
for v in ["b1", "b2", "b3", "b4"]:
118+
assert lams[v].shape == (2,), v
119+
120+
121+
@pytest.mark.xfail(reason="FIX shape computation for lambda")
122+
def test_lambda_shape_transformed_1d():
123+
with pm.Model(coords=dict(a=[1, 2])) as model:
124+
b1 = pm.Exponential("b1", 1, dims="a")
125+
b2 = pm.Exponential("b2", 1, shape=2)
126+
b3 = pm.Exponential("b3", 1, size=2)
127+
b4 = pm.Exponential("b4", np.asarray([1, 2]))
128+
model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
129+
lams = vip.get_lambda()
130+
for v in ["b1", "b2", "b3", "b4"]:
131+
assert lams[v].shape == (2,), v

0 commit comments

Comments
 (0)