Skip to content

Commit 17dca13

Browse files
ricardoV94twiecki
authored andcommitted
Reorganize and label Mixture test parametrizations
1 parent 62ef8b6 commit 17dca13

File tree

1 file changed

+107
-39
lines changed

1 file changed

+107
-39
lines changed

pymc/tests/logprob/test_mixture.py

Lines changed: 107 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_hetero_mixture_binomial(p_val, size):
226226
(),
227227
0,
228228
),
229-
# Degenerate vector mixture components, scalar index
229+
# Degenerate vector mixture components, scalar index along join axis
230230
(
231231
(
232232
np.array([0], dtype=pytensor.config.floatX),
@@ -246,7 +246,27 @@ def test_hetero_mixture_binomial(p_val, size):
246246
(),
247247
0,
248248
),
249-
# Scalar mixture components, vector index
249+
# Degenerate vector mixture components, scalar index along join axis (axis=1)
250+
(
251+
(
252+
np.array([0], dtype=pytensor.config.floatX),
253+
np.array(1, dtype=pytensor.config.floatX),
254+
),
255+
(
256+
np.array([0.5], dtype=pytensor.config.floatX),
257+
np.array(0.5, dtype=pytensor.config.floatX),
258+
),
259+
(
260+
np.array([100], dtype=pytensor.config.floatX),
261+
np.array(1, dtype=pytensor.config.floatX),
262+
),
263+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
264+
None,
265+
(),
266+
(slice(None),),
267+
1,
268+
),
269+
# Vector mixture components, scalar index along the join axis
250270
(
251271
(
252272
np.array(0, dtype=pytensor.config.floatX),
@@ -261,49 +281,72 @@ def test_hetero_mixture_binomial(p_val, size):
261281
np.array(1, dtype=pytensor.config.floatX),
262282
),
263283
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
284+
(4,),
264285
(),
265-
(6,),
266286
(),
267287
0,
268288
),
289+
# Vector mixture components, scalar index along the join axis (axis=1)
269290
(
270291
(
271-
np.array([0, -100], dtype=pytensor.config.floatX),
292+
np.array(0, dtype=pytensor.config.floatX),
272293
np.array(1, dtype=pytensor.config.floatX),
273294
),
274295
(
275-
np.array([0.5, 1], dtype=pytensor.config.floatX),
276-
np.array([0.5, 1], dtype=pytensor.config.floatX),
296+
np.array(0.5, dtype=pytensor.config.floatX),
297+
np.array(0.5, dtype=pytensor.config.floatX),
277298
),
278299
(
279-
np.array([100, 1000], dtype=pytensor.config.floatX),
300+
np.array(100, dtype=pytensor.config.floatX),
280301
np.array(1, dtype=pytensor.config.floatX),
281302
),
282-
np.array([[0.1, 0.5, 0.4], [0.4, 0.1, 0.5]], dtype=pytensor.config.floatX),
283-
(2,),
284-
(2,),
303+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
304+
(4,),
305+
(),
306+
(slice(None),),
307+
1,
308+
),
309+
# Matrix components, scalar index along first axis
310+
(
311+
(
312+
np.array(0, dtype=pytensor.config.floatX),
313+
np.array(1, dtype=pytensor.config.floatX),
314+
),
315+
(
316+
np.array(0.5, dtype=pytensor.config.floatX),
317+
np.array(0.5, dtype=pytensor.config.floatX),
318+
),
319+
(
320+
np.array(100, dtype=pytensor.config.floatX),
321+
np.array(1, dtype=pytensor.config.floatX),
322+
),
323+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
324+
(2, 3),
325+
(),
285326
(),
286327
0,
287328
),
329+
# Scalar mixture components, vector index along first axis
288330
(
289331
(
290-
np.array([0, -100], dtype=pytensor.config.floatX),
332+
np.array(0, dtype=pytensor.config.floatX),
291333
np.array(1, dtype=pytensor.config.floatX),
292334
),
293335
(
294-
np.array([0.5, 1], dtype=pytensor.config.floatX),
295-
np.array([0.5, 1], dtype=pytensor.config.floatX),
336+
np.array(0.5, dtype=pytensor.config.floatX),
337+
np.array(0.5, dtype=pytensor.config.floatX),
296338
),
297339
(
298-
np.array([100, 1000], dtype=pytensor.config.floatX),
340+
np.array(100, dtype=pytensor.config.floatX),
299341
np.array(1, dtype=pytensor.config.floatX),
300342
),
301-
np.array([[0.1, 0.5, 0.4], [0.4, 0.1, 0.5]], dtype=pytensor.config.floatX),
302-
None,
303-
None,
343+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
344+
(),
345+
(6,),
304346
(),
305347
0,
306348
),
349+
# Vector mixture components, vector index along first axis
307350
(
308351
(
309352
np.array(0, dtype=pytensor.config.floatX),
@@ -320,10 +363,31 @@ def test_hetero_mixture_binomial(p_val, size):
320363
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
321364
(2,),
322365
(2,),
323-
(),
366+
(slice(None),),
324367
0,
325368
),
326-
# Same as before but with degenerate vector parameters
369+
# Vector mixture components, vector index along last axis
370+
pytest.param(
371+
(
372+
np.array(0, dtype=pytensor.config.floatX),
373+
np.array(1, dtype=pytensor.config.floatX),
374+
),
375+
(
376+
np.array(0.5, dtype=pytensor.config.floatX),
377+
np.array(0.5, dtype=pytensor.config.floatX),
378+
),
379+
(
380+
np.array(100, dtype=pytensor.config.floatX),
381+
np.array(1, dtype=pytensor.config.floatX),
382+
),
383+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
384+
(2,),
385+
(4,),
386+
(slice(None),),
387+
1,
388+
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
389+
),
390+
# Vector mixture components (with degenerate vector parameters), vector index along first axis
327391
(
328392
(
329393
np.array([0], dtype=pytensor.config.floatX),
@@ -343,45 +407,48 @@ def test_hetero_mixture_binomial(p_val, size):
343407
(),
344408
0,
345409
),
410+
# Vector mixture components (with vector parameters), vector index along first axis
346411
(
347412
(
348-
np.array(0, dtype=pytensor.config.floatX),
413+
np.array([0, -100], dtype=pytensor.config.floatX),
349414
np.array(1, dtype=pytensor.config.floatX),
350415
),
351416
(
352-
np.array(0.5, dtype=pytensor.config.floatX),
353-
np.array(0.5, dtype=pytensor.config.floatX),
417+
np.array([0.5, 1], dtype=pytensor.config.floatX),
418+
np.array([0.5, 1], dtype=pytensor.config.floatX),
354419
),
355420
(
356-
np.array(100, dtype=pytensor.config.floatX),
421+
np.array([100, 1000], dtype=pytensor.config.floatX),
357422
np.array(1, dtype=pytensor.config.floatX),
358423
),
359-
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
360-
(2, 3),
361-
(2, 3),
424+
np.array([[0.1, 0.5, 0.4], [0.4, 0.1, 0.5]], dtype=pytensor.config.floatX),
425+
(2,),
426+
(2,),
362427
(),
363428
0,
364429
),
430+
# Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
365431
(
366432
(
367-
np.array(0, dtype=pytensor.config.floatX),
433+
np.array([0, -100], dtype=pytensor.config.floatX),
368434
np.array(1, dtype=pytensor.config.floatX),
369435
),
370436
(
371-
np.array(0.5, dtype=pytensor.config.floatX),
372-
np.array(0.5, dtype=pytensor.config.floatX),
437+
np.array([0.5, 1], dtype=pytensor.config.floatX),
438+
np.array([0.5, 1], dtype=pytensor.config.floatX),
373439
),
374440
(
375-
np.array(100, dtype=pytensor.config.floatX),
441+
np.array([100, 1000], dtype=pytensor.config.floatX),
376442
np.array(1, dtype=pytensor.config.floatX),
377443
),
378-
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
379-
(2, 3),
380-
(),
444+
np.array([[0.1, 0.5, 0.4], [0.4, 0.1, 0.5]], dtype=pytensor.config.floatX),
445+
None,
446+
None,
381447
(),
382448
0,
383449
),
384-
pytest.param(
450+
# Matrix mixture components, matrix index
451+
(
385452
(
386453
np.array(0, dtype=pytensor.config.floatX),
387454
np.array(1, dtype=pytensor.config.floatX),
@@ -395,12 +462,12 @@ def test_hetero_mixture_binomial(p_val, size):
395462
np.array(1, dtype=pytensor.config.floatX),
396463
),
397464
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
398-
(3,),
399-
(3,),
400-
(slice(None),),
401-
1,
402-
marks=pytest.mark.xfail(IndexError, reason="Bug in AdvancedIndex Mixture logprob"),
465+
(2, 3),
466+
(2, 3),
467+
(),
468+
0,
403469
),
470+
# Vector components, matrix indexing (constant along first dimension, then random)
404471
(
405472
(
406473
np.array(0, dtype=pytensor.config.floatX),
@@ -420,6 +487,7 @@ def test_hetero_mixture_binomial(p_val, size):
420487
(np.arange(5),),
421488
0,
422489
),
490+
# Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
423491
(
424492
(
425493
np.array(0, dtype=pytensor.config.floatX),

0 commit comments

Comments
 (0)