@@ -253,9 +253,54 @@ def test_insert_draw(self):
253
253
assert v1 == 12
254
254
numpy .testing .assert_array_equal (v2 , draw ["v2" ])
255
255
numpy .testing .assert_array_equal (v3 , draw ["v3" ])
256
+ pass
257
+
258
+ @pytest .mark .xfail (reason = "issue #56" )
259
+ def test_get_row_at (self ):
260
+ run , chains = fully_initialized (
261
+ self .backend ,
262
+ make_runmeta (
263
+ variables = [
264
+ Variable ("v1" , "uint16" , []),
265
+ Variable ("v2" , "float32" , list ((3 ,))),
266
+ ],
267
+ ),
268
+ )
269
+ chain = chains [0 ]
270
+ for i in range (10 ):
271
+ chain .append (dict (v1 = i , v2 = numpy .array ([i , 2 , 3 ])))
272
+ assert len (chain ) == 10
273
+
274
+ row5 = chain .get_draws_at (5 , ["v1" , "v2" ])
275
+ assert "v1" in row5
276
+ assert "v2" in row5
277
+ assert row5 ["v1" ] == 5
278
+ assert tuple (row5 ["v2" ]) == (5 , 2 , 3 )
256
279
257
280
with pytest .raises (Exception , match = "No record found for draw" ):
258
- chain ._get_row_at (2 , var_names = ["v1" ])
281
+ chain ._get_row_at (20 , var_names = ["v1" ])
282
+
283
+ # Issue #56 was caused by querying just one variable
284
+ assert len (chain ._get_row_at (5 , var_names = ["v1" ])) == 1
285
+ pass
286
+
287
+ @pytest .mark .xfail (reason = "issue #37" )
288
+ def test_exotic_var_names (self ):
289
+ run , chains = fully_initialized (
290
+ self .backend ,
291
+ make_runmeta (
292
+ variables = [
293
+ Variable ("v1[a]" , "uint16" , []),
294
+ ],
295
+ ),
296
+ )
297
+ chain = chains [0 ]
298
+ for i in range (10 ):
299
+ chain .append ({var .name : i for var in run .meta .variables })
300
+ assert len (chain ) == 10
301
+
302
+ row2 = chain ._get_row_at (2 , var_names = ["v1[a]" ])
303
+ assert "v1[a]" in row2
259
304
pass
260
305
261
306
def test_to_inferencedata_equalize_chain_lengths (self , caplog ):
0 commit comments