@@ -103,3 +103,29 @@ def test_set_truncate(model_c: pm.Model):
103
103
vip .truncate_lambda (g = 0.2 )
104
104
np .testing .assert_allclose (vip .get_lambda ()["g" ], 1 )
105
105
np .testing .assert_allclose (vip .get_lambda ()["m" ], 0.9 )
106
+
107
+
108
+ @pytest .mark .xfail (reason = "FIX shape computation for lambda" )
109
+ def test_lambda_shape ():
110
+ with pm .Model (coords = dict (a = [1 , 2 ])) as model :
111
+ b1 = pm .Normal ("b1" , dims = "a" )
112
+ b2 = pm .Normal ("b2" , shape = 2 )
113
+ b3 = pm .Normal ("b3" , size = 2 )
114
+ b4 = pm .Normal ("b4" , np .asarray ([1 , 2 ]))
115
+ model_v , vip = vip_reparametrize (model , ["b1" , "b2" , "b3" , "b4" ])
116
+ lams = vip .get_lambda ()
117
+ for v in ["b1" , "b2" , "b3" , "b4" ]:
118
+ assert lams [v ].shape == (2 ,), v
119
+
120
+
121
+ @pytest .mark .xfail (reason = "FIX shape computation for lambda" )
122
+ def test_lambda_shape_transformed_1d ():
123
+ with pm .Model (coords = dict (a = [1 , 2 ])) as model :
124
+ b1 = pm .Exponential ("b1" , 1 , dims = "a" )
125
+ b2 = pm .Exponential ("b2" , 1 , shape = 2 )
126
+ b3 = pm .Exponential ("b3" , 1 , size = 2 )
127
+ b4 = pm .Exponential ("b4" , np .asarray ([1 , 2 ]))
128
+ model_v , vip = vip_reparametrize (model , ["b1" , "b2" , "b3" , "b4" ])
129
+ lams = vip .get_lambda ()
130
+ for v in ["b1" , "b2" , "b3" , "b4" ]:
131
+ assert lams [v ].shape == (2 ,), v
0 commit comments