@@ -153,6 +153,16 @@ def test_get_jaxified_logp():
153
153
assert not np .isinf (jax_fn ((np .array (5000.0 ), np .array (5000.0 ))))
154
154
155
155
156
+ @pytest .fixture
157
+ def model_test_idata_kwargs (scope = "module" ):
158
+ with pm .Model (coords = {"x_coord" : ["a" , "b" ], "x_coord2" : [1 , 2 ]}) as m :
159
+ x = pm .Normal ("x" , shape = (2 ,), dims = ["x_coord" ])
160
+ y = pm .Normal ("y" , x , observed = [0 , 0 ])
161
+ pm .ConstantData ("constantdata" , [1 , 2 , 3 ])
162
+ pm .MutableData ("mutabledata" , 2 )
163
+ return m
164
+
165
+
156
166
@pytest .mark .parametrize (
157
167
"sampler" ,
158
168
[
@@ -165,15 +175,17 @@ def test_get_jaxified_logp():
165
175
[
166
176
dict (),
167
177
dict (log_likelihood = False ),
178
+ # Overwrite models coords
179
+ dict (coords = {"x_coord" : ["x1" , "x2" ]}),
180
+ # Overwrite dims from dist specification in model
181
+ dict (dims = {"x" : ["x_coord2" ]}),
182
+ # Overwrite both coords and dims
183
+ dict (coords = {"x_coord3" : ["A" , "B" ]}, dims = {"x" : ["x_coord3" ]}),
168
184
],
169
185
)
170
186
@pytest .mark .parametrize ("postprocessing_backend" , [None , "cpu" ])
171
- def test_idata_kwargs (sampler , idata_kwargs , postprocessing_backend ):
172
- with pm .Model () as m :
173
- x = pm .Normal ("x" )
174
- y = pm .Normal ("y" , x , observed = 0 )
175
- pm .ConstantData ("constantdata" , [1 , 2 , 3 ])
176
- pm .MutableData ("mutabledata" , 2 )
187
+ def test_idata_kwargs (model_test_idata_kwargs , sampler , idata_kwargs , postprocessing_backend ):
188
+ with model_test_idata_kwargs :
177
189
idata = sampler (
178
190
tune = 50 ,
179
191
draws = 50 ,
@@ -189,6 +201,12 @@ def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
189
201
else :
190
202
assert "log_likelihood" not in idata
191
203
204
+ x_dim_expected = idata_kwargs .get ("dims" , model_test_idata_kwargs .RV_dims )["x" ][0 ]
205
+ assert idata .posterior .x .dims [- 1 ] == x_dim_expected
206
+
207
+ x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[x_dim_expected ]
208
+ assert list (x_coords_expected ) == list (idata .posterior .x .coords [x_dim_expected ].values )
209
+
192
210
193
211
def test_get_batched_jittered_initial_points ():
194
212
with pm .Model () as model :
0 commit comments