@@ -269,34 +269,23 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
269
269
270
270
271
271
@pytest .mark .parametrize (
272
- "size1, supp_size1, size2, supp_size2, axis, concatenate" ,
272
+ "size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis " ,
273
273
[
274
- (None , 2 , None , 2 , 0 , True ),
275
- (None , 2 , None , 2 , - 1 , True ),
276
- ((5 ,), 2 , (3 ,), 2 , 0 , True ),
277
- ((5 ,), 2 , (3 ,), 2 , - 2 , True ),
278
- ((2 ,), 5 , (2 ,), 3 , 1 , True ),
279
- pytest .param (
280
- (2 ,),
281
- 5 ,
282
- (2 ,),
283
- 5 ,
284
- 0 ,
285
- False ,
286
- marks = pytest .mark .xfail (reason = "cannot measure dimshuffled multivariate RVs" ),
287
- ),
288
- pytest .param (
289
- (2 ,),
290
- 5 ,
291
- (2 ,),
292
- 5 ,
293
- 1 ,
294
- False ,
295
- marks = pytest .mark .xfail (reason = "cannot measure dimshuffled multivariate RVs" ),
296
- ),
274
+ (None , 2 , None , 2 , 0 , True , 0 ),
275
+ (None , 2 , None , 2 , - 1 , True , 0 ),
276
+ ((5 ,), 2 , (3 ,), 2 , 0 , True , 0 ),
277
+ ((5 ,), 2 , (3 ,), 2 , - 2 , True , 0 ),
278
+ ((2 ,), 5 , (2 ,), 3 , 1 , True , 0 ),
279
+ ((5 , 6 ), 10 , (5 , 1 ), 10 , 1 , True , 1 ),
280
+ ((5 , 6 ), 10 , (5 , 1 ), 10 , - 2 , True , 1 ),
281
+ ((2 ,), 5 , (2 ,), 5 , 0 , False , 0 ),
282
+ ((2 ,), 5 , (2 ,), 5 , 1 , False , 1 ),
283
+ ((5 , 6 ), 10 , (5 , 6 ), 10 , 2 , False , 2 ),
297
284
],
298
285
)
299
- def test_measurable_join_multivariate (size1 , supp_size1 , size2 , supp_size2 , axis , concatenate ):
286
+ def test_measurable_join_multivariate (
287
+ size1 , supp_size1 , size2 , supp_size2 , axis , concatenate , logp_axis
288
+ ):
300
289
base1_rv = pt .random .multivariate_normal (
301
290
np .zeros (supp_size1 ), np .eye (supp_size1 ), size = size1 , name = "base1"
302
291
)
@@ -310,19 +299,18 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
310
299
base1_vv = base1_rv .clone ()
311
300
base2_vv = base2_rv .clone ()
312
301
y_vv = y_rv .clone ()
302
+
303
+ y_logp = logp (y_rv , y_vv )
304
+ assert_no_rvs (y_logp )
305
+
313
306
base_logps = [
314
307
pt .atleast_1d (logp )
315
308
for logp in conditional_logp ({base1_rv : base1_vv , base2_rv : base2_vv }).values ()
316
309
]
317
-
318
310
if concatenate :
319
- axis_norm = np .core .numeric .normalize_axis_index (axis , base1_rv .ndim )
320
- base_logps = pt .concatenate (base_logps , axis = axis_norm - 1 )
311
+ expected_logp = pt .concatenate (base_logps , axis = logp_axis )
321
312
else :
322
- axis_norm = np .core .numeric .normalize_axis_index (axis , base1_rv .ndim + 1 )
323
- base_logps = pt .stack (base_logps , axis = axis_norm - 1 )
324
- y_logp = y_logp = logp (y_rv , y_vv )
325
- assert_no_rvs (y_logp )
313
+ expected_logp = pt .stack (base_logps , axis = logp_axis )
326
314
327
315
base1_testval = base1_rv .eval ()
328
316
base2_testval = base2_rv .eval ()
@@ -331,7 +319,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
331
319
else :
332
320
y_testval = np .stack ((base1_testval , base2_testval ), axis = axis )
333
321
np .testing .assert_allclose (
334
- base_logps .eval ({base1_vv : base1_testval , base2_vv : base2_testval }),
322
+ expected_logp .eval ({base1_vv : base1_testval , base2_vv : base2_testval }),
335
323
y_logp .eval ({y_vv : y_testval }),
336
324
)
337
325
0 commit comments