@@ -226,6 +226,26 @@ def test_hetero_mixture_binomial(p_val, size):
226
226
(),
227
227
0 ,
228
228
),
229
+ # Degenerate vector mixture components, scalar index
230
+ (
231
+ (
232
+ np .array ([0 ], dtype = pytensor .config .floatX ),
233
+ np .array (1 , dtype = pytensor .config .floatX ),
234
+ ),
235
+ (
236
+ np .array ([0.5 ], dtype = pytensor .config .floatX ),
237
+ np .array (0.5 , dtype = pytensor .config .floatX ),
238
+ ),
239
+ (
240
+ np .array ([100 ], dtype = pytensor .config .floatX ),
241
+ np .array (1 , dtype = pytensor .config .floatX ),
242
+ ),
243
+ np .array ([0.1 , 0.5 , 0.4 ], dtype = pytensor .config .floatX ),
244
+ None ,
245
+ (),
246
+ (),
247
+ 0 ,
248
+ ),
229
249
# Scalar mixture components, vector index
230
250
(
231
251
(
@@ -443,16 +463,25 @@ def test_hetero_mixture_categorical(
443
463
gamma_sp = sp .gamma (Y_args [0 ], scale = 1 / Y_args [1 ])
444
464
norm_2_sp = sp .norm (loc = Z_args [0 ], scale = Z_args [1 ])
445
465
466
+ # Handle scipy annoying squeeze of random draws
467
+ real_comp_size = tuple (X_rv .shape .eval ())
468
+
446
469
for i in range (10 ):
447
470
i_val = CategoricalRV .rng_fn (test_val_rng , p_val , idx_size )
448
471
449
472
indices_val = list (extra_indices )
450
473
indices_val .insert (join_axis , i_val )
451
474
indices_val = tuple (indices_val )
452
475
453
- x_val = norm_1_sp .rvs (size = comp_size , random_state = test_val_rng )
454
- y_val = gamma_sp .rvs (size = comp_size , random_state = test_val_rng )
455
- z_val = norm_2_sp .rvs (size = comp_size , random_state = test_val_rng )
476
+ x_val = np .broadcast_to (
477
+ norm_1_sp .rvs (size = comp_size , random_state = test_val_rng ), real_comp_size
478
+ )
479
+ y_val = np .broadcast_to (
480
+ gamma_sp .rvs (size = comp_size , random_state = test_val_rng ), real_comp_size
481
+ )
482
+ z_val = np .broadcast_to (
483
+ norm_2_sp .rvs (size = comp_size , random_state = test_val_rng ), real_comp_size
484
+ )
456
485
457
486
component_logps = np .stack (
458
487
[norm_1_sp .logpdf (x_val ), gamma_sp .logpdf (y_val ), norm_2_sp .logpdf (z_val )],
0 commit comments