@@ -226,7 +226,7 @@ def test_hetero_mixture_binomial(p_val, size):
226
226
(),
227
227
0 ,
228
228
),
229
- # Degenerate vector mixture components, scalar index
229
+ # Degenerate vector mixture components, scalar index along join axis
230
230
(
231
231
(
232
232
np .array ([0 ], dtype = pytensor .config .floatX ),
@@ -246,7 +246,27 @@ def test_hetero_mixture_binomial(p_val, size):
246
246
(),
247
247
0 ,
248
248
),
249
- # Scalar mixture components, vector index
249
+ # Degenerate vector mixture components, scalar index along join axis (axis=1)
250
+ (
251
+ (
252
+ np .array ([0 ], dtype = pytensor .config .floatX ),
253
+ np .array (1 , dtype = pytensor .config .floatX ),
254
+ ),
255
+ (
256
+ np .array ([0.5 ], dtype = pytensor .config .floatX ),
257
+ np .array (0.5 , dtype = pytensor .config .floatX ),
258
+ ),
259
+ (
260
+ np .array ([100 ], dtype = pytensor .config .floatX ),
261
+ np .array (1 , dtype = pytensor .config .floatX ),
262
+ ),
263
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
264
+ None ,
265
+ (),
266
+ (slice (None ),),
267
+ 1 ,
268
+ ),
269
+ # Vector mixture components, scalar index along the join axis
250
270
(
251
271
(
252
272
np .array (0 , dtype = pytensor .config .floatX ),
@@ -261,49 +281,72 @@ def test_hetero_mixture_binomial(p_val, size):
261
281
np .array (1 , dtype = pytensor .config .floatX ),
262
282
),
263
283
np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
284
+ (4 ,),
264
285
(),
265
- (6 ,),
266
286
(),
267
287
0 ,
268
288
),
289
+ # Vector mixture components, scalar index along the join axis (axis=1)
269
290
(
270
291
(
271
- np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
292
+ np .array (0 , dtype = pytensor .config .floatX ),
272
293
np .array (1 , dtype = pytensor .config .floatX ),
273
294
),
274
295
(
275
- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
276
- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
296
+ np .array (0.5 , dtype = pytensor .config .floatX ),
297
+ np .array (0.5 , dtype = pytensor .config .floatX ),
277
298
),
278
299
(
279
- np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
300
+ np .array (100 , dtype = pytensor .config .floatX ),
280
301
np .array (1 , dtype = pytensor .config .floatX ),
281
302
),
282
- np .array ([[0.1 , 0.5 , 0.4 ], [0.4 , 0.1 , 0.5 ]], dtype = pytensor .config .floatX ),
283
- (2 ,),
284
- (2 ,),
303
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
304
+ (4 ,),
305
+ (),
306
+ (slice (None ),),
307
+ 1 ,
308
+ ),
309
+ # Matrix components, scalar index along first axis
310
+ (
311
+ (
312
+ np .array (0 , dtype = pytensor .config .floatX ),
313
+ np .array (1 , dtype = pytensor .config .floatX ),
314
+ ),
315
+ (
316
+ np .array (0.5 , dtype = pytensor .config .floatX ),
317
+ np .array (0.5 , dtype = pytensor .config .floatX ),
318
+ ),
319
+ (
320
+ np .array (100 , dtype = pytensor .config .floatX ),
321
+ np .array (1 , dtype = pytensor .config .floatX ),
322
+ ),
323
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
324
+ (2 , 3 ),
325
+ (),
285
326
(),
286
327
0 ,
287
328
),
329
+ # Scalar mixture components, vector index along first axis
288
330
(
289
331
(
290
- np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
332
+ np .array (0 , dtype = pytensor .config .floatX ),
291
333
np .array (1 , dtype = pytensor .config .floatX ),
292
334
),
293
335
(
294
- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
295
- np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
336
+ np .array (0.5 , dtype = pytensor .config .floatX ),
337
+ np .array (0.5 , dtype = pytensor .config .floatX ),
296
338
),
297
339
(
298
- np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
340
+ np .array (100 , dtype = pytensor .config .floatX ),
299
341
np .array (1 , dtype = pytensor .config .floatX ),
300
342
),
301
- np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
302
- None ,
303
- None ,
343
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
344
+ () ,
345
+ ( 6 ,) ,
304
346
(),
305
347
0 ,
306
348
),
349
+ # Vector mixture components, vector index along first axis
307
350
(
308
351
(
309
352
np .array (0 , dtype = pytensor .config .floatX ),
@@ -320,10 +363,31 @@ def test_hetero_mixture_binomial(p_val, size):
320
363
np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
321
364
(2 ,),
322
365
(2 ,),
323
- (),
366
+ (slice ( None ), ),
324
367
0 ,
325
368
),
326
- # Same as before but with degenerate vector parameters
369
+ # Vector mixture components, vector index along last axis
370
+ pytest .param (
371
+ (
372
+ np .array (0 , dtype = pytensor .config .floatX ),
373
+ np .array (1 , dtype = pytensor .config .floatX ),
374
+ ),
375
+ (
376
+ np .array (0.5 , dtype = pytensor .config .floatX ),
377
+ np .array (0.5 , dtype = pytensor .config .floatX ),
378
+ ),
379
+ (
380
+ np .array (100 , dtype = pytensor .config .floatX ),
381
+ np .array (1 , dtype = pytensor .config .floatX ),
382
+ ),
383
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
384
+ (2 ,),
385
+ (4 ,),
386
+ (slice (None ),),
387
+ 1 ,
388
+ marks = pytest .mark .xfail (IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ),
389
+ ),
390
+ # Vector mixture components (with degenerate vector parameters), vector index along first axis
327
391
(
328
392
(
329
393
np .array ([0 ], dtype = pytensor .config .floatX ),
@@ -343,45 +407,48 @@ def test_hetero_mixture_binomial(p_val, size):
343
407
(),
344
408
0 ,
345
409
),
410
+ # Vector mixture components (with vector parameters), vector index along first axis
346
411
(
347
412
(
348
- np .array (0 , dtype = pytensor .config .floatX ),
413
+ np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
349
414
np .array (1 , dtype = pytensor .config .floatX ),
350
415
),
351
416
(
352
- np .array (0.5 , dtype = pytensor .config .floatX ),
353
- np .array (0.5 , dtype = pytensor .config .floatX ),
417
+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
418
+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
354
419
),
355
420
(
356
- np .array (100 , dtype = pytensor .config .floatX ),
421
+ np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
357
422
np .array (1 , dtype = pytensor .config .floatX ),
358
423
),
359
- np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
360
- (2 , 3 ),
361
- (2 , 3 ),
424
+ np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
425
+ (2 ,),
426
+ (2 ,),
362
427
(),
363
428
0 ,
364
429
),
430
+ # Vector mixture components (with vector parameters), vector index along first axis, implicit sizes
365
431
(
366
432
(
367
- np .array (0 , dtype = pytensor .config .floatX ),
433
+ np .array ([ 0 , - 100 ] , dtype = pytensor .config .floatX ),
368
434
np .array (1 , dtype = pytensor .config .floatX ),
369
435
),
370
436
(
371
- np .array (0.5 , dtype = pytensor .config .floatX ),
372
- np .array (0.5 , dtype = pytensor .config .floatX ),
437
+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
438
+ np .array ([ 0.5 , 1 ] , dtype = pytensor .config .floatX ),
373
439
),
374
440
(
375
- np .array (100 , dtype = pytensor .config .floatX ),
441
+ np .array ([ 100 , 1000 ] , dtype = pytensor .config .floatX ),
376
442
np .array (1 , dtype = pytensor .config .floatX ),
377
443
),
378
- np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
379
- ( 2 , 3 ) ,
380
- () ,
444
+ np .array ([[ 0.1 , 0.5 , 0.4 ], [ 0.4 , 0.1 , 0.5 ] ], dtype = pytensor .config .floatX ),
445
+ None ,
446
+ None ,
381
447
(),
382
448
0 ,
383
449
),
384
- pytest .param (
450
+ # Matrix mixture components, matrix index
451
+ (
385
452
(
386
453
np .array (0 , dtype = pytensor .config .floatX ),
387
454
np .array (1 , dtype = pytensor .config .floatX ),
@@ -395,12 +462,12 @@ def test_hetero_mixture_binomial(p_val, size):
395
462
np .array (1 , dtype = pytensor .config .floatX ),
396
463
),
397
464
np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
398
- (3 ,),
399
- (3 ,),
400
- (slice (None ),),
401
- 1 ,
402
- marks = pytest .mark .xfail (IndexError , reason = "Bug in AdvancedIndex Mixture logprob" ),
465
+ (2 , 3 ),
466
+ (2 , 3 ),
467
+ (),
468
+ 0 ,
403
469
),
470
+ # Vector components, matrix indexing (constant along first dimension, then random)
404
471
(
405
472
(
406
473
np .array (0 , dtype = pytensor .config .floatX ),
@@ -420,6 +487,7 @@ def test_hetero_mixture_binomial(p_val, size):
420
487
(np .arange (5 ),),
421
488
0 ,
422
489
),
490
+ # Vector mixture components, tensor3 indexing (constant along first dimension, then degenerate, then random)
423
491
(
424
492
(
425
493
np .array (0 , dtype = pytensor .config .floatX ),
0 commit comments