1
- from typing import Any , Dict
1
+ from typing import Any , Callable , Dict , Optional
2
2
from unittest import mock
3
3
4
4
import aesara
5
5
import aesara .tensor as at
6
+ import arviz as az
6
7
import jax
7
8
import numpy as np
8
9
import pytest
@@ -159,11 +160,14 @@ def test_get_jaxified_logp():
159
160
assert not np .isinf (jax_fn ((np .array (5000.0 ), np .array (5000.0 ))))
160
161
161
162
162
- @pytest .fixture
163
- def model_test_idata_kwargs (scope = "module" ):
164
- with pm .Model (coords = {"x_coord" : ["a" , "b" ], "x_coord2" : [1 , 2 ]}) as m :
163
+ @pytest .fixture (scope = "module" )
164
+ def model_test_idata_kwargs () -> pm .Model :
165
+ with pm .Model (
166
+ coords = {"x_coord" : ["a" , "b" ], "x_coord2" : [1 , 2 ], "z_coord" : ["apple" , "banana" , "orange" ]}
167
+ ) as m :
165
168
x = pm .Normal ("x" , shape = (2 ,), dims = ["x_coord" ])
166
- y = pm .Normal ("y" , x , observed = [0 , 0 ])
169
+ _ = pm .Normal ("y" , x , observed = [0 , 0 ])
170
+ _ = pm .Normal ("z" , 0 , 1 , dims = "z_coord" )
167
171
pm .ConstantData ("constantdata" , [1 , 2 , 3 ])
168
172
pm .MutableData ("mutabledata" , 2 )
169
173
return m
@@ -190,7 +194,13 @@ def model_test_idata_kwargs(scope="module"):
190
194
],
191
195
)
192
196
@pytest .mark .parametrize ("postprocessing_backend" , [None , "cpu" ])
193
- def test_idata_kwargs (model_test_idata_kwargs , sampler , idata_kwargs , postprocessing_backend ):
197
+ def test_idata_kwargs (
198
+ model_test_idata_kwargs : pm .Model ,
199
+ sampler : Callable [..., az .InferenceData ],
200
+ idata_kwargs : Dict [str , Any ],
201
+ postprocessing_backend : Optional [str ],
202
+ ):
203
+ idata : Optional [az .InferenceData ] = None
194
204
with model_test_idata_kwargs :
195
205
idata = sampler (
196
206
tune = 50 ,
@@ -199,19 +209,31 @@ def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postproces
199
209
idata_kwargs = idata_kwargs ,
200
210
postprocessing_backend = postprocessing_backend ,
201
211
)
202
- assert "constantdata" in idata .constant_data
203
- assert "mutabledata" in idata .constant_data
212
+ assert idata is not None
213
+ const_data = idata .get ("constant_data" )
214
+ assert const_data is not None
215
+ assert "constantdata" in const_data
216
+ assert "mutabledata" in const_data
204
217
205
218
if idata_kwargs .get ("log_likelihood" , True ):
206
219
assert "log_likelihood" in idata
207
220
else :
208
221
assert "log_likelihood" not in idata
209
222
223
+ posterior = idata .get ("posterior" )
224
+ assert posterior is not None
210
225
x_dim_expected = idata_kwargs .get ("dims" , model_test_idata_kwargs .RV_dims )["x" ][0 ]
211
- assert idata .posterior .x .dims [- 1 ] == x_dim_expected
226
+ assert x_dim_expected is not None
227
+ assert posterior ["x" ].dims [- 1 ] == x_dim_expected
212
228
213
229
x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[x_dim_expected ]
214
- assert list (x_coords_expected ) == list (idata .posterior .x .coords [x_dim_expected ].values )
230
+ assert x_coords_expected is not None
231
+ assert list (x_coords_expected ) == list (posterior ["x" ].coords [x_dim_expected ].values )
232
+
233
+ assert posterior ["z" ].dims [2 ] == "z_coord"
234
+ assert np .all (
235
+ posterior ["z" ].coords ["z_coord" ].values == np .array (["apple" , "banana" , "orange" ])
236
+ )
215
237
216
238
217
239
def test_get_batched_jittered_initial_points ():
0 commit comments