Skip to content

Commit c806374

Browse files
committed
Fix bug in IfElse Mixture logprob
It did not account for the extra dimension added by the stacking operation, resulting in a logp call to an expanded RV with the original unexpanded value
1 parent 6bd4f34 commit c806374

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

pymc/logprob/mixture.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,23 @@ def logprob_MixtureRV(
443443
logp_val = at.set_subtensor(logp_val[idx_m_on_axis], logp_m)
444444

445445
else:
446+
# If the stacking operation expands the component RVs, we have
447+
# to expand the value and later squeeze the logprob for everything
448+
# to work correctly
449+
join_axis_val = None if isinstance(join_axis.type, NoneTypeT) else join_axis.data
450+
451+
if join_axis_val is not None:
452+
value = at.expand_dims(value, axis=join_axis_val)
453+
446454
logp_val = 0.0
447455
for i, comp_rv in enumerate(comp_rvs):
448456
comp_logp = logprob(comp_rv, value)
457+
if join_axis_val is not None:
458+
comp_logp = at.squeeze(comp_logp, axis=join_axis_val)
449459
logp_val += ifelse(
450460
at.eq(indices[0], i),
451461
comp_logp,
452-
at.zeros_like(value),
462+
at.zeros_like(comp_logp),
453463
)
454464

455465
return logp_val

pymc/tests/logprob/test_mixture.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,26 @@ def test_hetero_mixture_binomial(p_val, size):
226226
(),
227227
0,
228228
),
229+
# Degenerate vector mixture components, scalar index
230+
(
231+
(
232+
np.array([0], dtype=pytensor.config.floatX),
233+
np.array(1, dtype=pytensor.config.floatX),
234+
),
235+
(
236+
np.array([0.5], dtype=pytensor.config.floatX),
237+
np.array(0.5, dtype=pytensor.config.floatX),
238+
),
239+
(
240+
np.array([100], dtype=pytensor.config.floatX),
241+
np.array(1, dtype=pytensor.config.floatX),
242+
),
243+
np.array([0.1, 0.5, 0.4], dtype=pytensor.config.floatX),
244+
None,
245+
(),
246+
(),
247+
0,
248+
),
229249
# Scalar mixture components, vector index
230250
(
231251
(
@@ -443,16 +463,25 @@ def test_hetero_mixture_categorical(
443463
gamma_sp = sp.gamma(Y_args[0], scale=1 / Y_args[1])
444464
norm_2_sp = sp.norm(loc=Z_args[0], scale=Z_args[1])
445465

466+
# Handle scipy annoying squeeze of random draws
467+
real_comp_size = tuple(X_rv.shape.eval())
468+
446469
for i in range(10):
447470
i_val = CategoricalRV.rng_fn(test_val_rng, p_val, idx_size)
448471

449472
indices_val = list(extra_indices)
450473
indices_val.insert(join_axis, i_val)
451474
indices_val = tuple(indices_val)
452475

453-
x_val = norm_1_sp.rvs(size=comp_size, random_state=test_val_rng)
454-
y_val = gamma_sp.rvs(size=comp_size, random_state=test_val_rng)
455-
z_val = norm_2_sp.rvs(size=comp_size, random_state=test_val_rng)
476+
x_val = np.broadcast_to(
477+
norm_1_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
478+
)
479+
y_val = np.broadcast_to(
480+
gamma_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
481+
)
482+
z_val = np.broadcast_to(
483+
norm_2_sp.rvs(size=comp_size, random_state=test_val_rng), real_comp_size
484+
)
456485

457486
component_logps = np.stack(
458487
[norm_1_sp.logpdf(x_val), gamma_sp.logpdf(y_val), norm_2_sp.logpdf(z_val)],

0 commit comments

Comments
 (0)