Skip to content

Commit 485dc21

Browse files
committed
Handle required static shape in vip_reparametrize
1 parent d072b94 commit 485dc21

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc_experimental/tests/model/transforms/test_autoreparam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def test_multilevel():
7070
# multilevel modelling
7171
a = pm.Normal("a")
7272
s = pm.HalfNormal("s")
73-
a_g = pm.Normal("a_g", a, s, dims="level")
73+
a_g = pm.Normal("a_g", a, s, shape=(2,), dims="level")
7474
s_g = pm.HalfNormal("s_g")
75-
a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
75+
a_ig = pm.Normal("a_ig", a_g, s_g, shape=(2, 2), dims=("county", "level"))
7676

7777
model_r, vip = vip_reparametrize(model, ["a_g", "a_ig"])
7878
assert "a_g" in vip.get_lambda()

0 commit comments

Comments
 (0)