Skip to content

Commit 7929b71

Browse files
authored
Fix bug in mixture logprob inference
1 parent 4220ed8 commit 7929b71

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

pymc/logprob/mixture.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
is_basic_idx,
6363
)
6464
from pytensor.tensor.type import TensorType
65-
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType
65+
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
6666
from pytensor.tensor.variable import TensorVariable
6767

6868
from pymc.logprob.abstract import (
@@ -289,9 +289,10 @@ def find_measurable_index_mixture(fgraph, node):
289289
# We don't support (non-scalar) integer array indexing as it can pick repeated values,
290290
# but the Mixture logprob assumes all mixture values are independent
291291
if any(
292-
indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
292+
isinstance(indices, TensorVariable)
293+
and indices.dtype.startswith("int")
294+
and not all(indices.type.broadcastable)
293295
for indices in mixing_indices
294-
if not isinstance(indices, SliceConstant)
295296
):
296297
return None
297298

tests/logprob/test_mixture.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,3 +1156,30 @@ def test_nested_ifelse():
11561156
np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1))
11571157
np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1))
11581158
np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1))
1159+
1160+
1161+
def test_advanced_subtensor_none_and_integer():
1162+
"""
1163+
Test for correct error handling when the logp graph is over-specified.
1164+
1165+
Providing values for both a random variable ('a') and its deterministic
1166+
child ('b') creates a logical conflict. The system should detect this
1167+
and raise a controlled RuntimeError.
1168+
1169+
This test fails if the rewriter instead crashes with the old internal
1170+
AttributeError bug, which would indicate a regression. Please see: #7762
1171+
"""
1172+
a = pt.random.normal(0, 1, size=(10,), name="a")
1173+
inds = np.array([0, 1, 2, 3], dtype="int32")
1174+
b = a[None, inds]
1175+
1176+
b_val = b.type()
1177+
b_val.name = "b_val"
1178+
a_val = a.type()
1179+
a_val.name = "a_val"
1180+
1181+
with pytest.raises(
1182+
RuntimeError,
1183+
match="logprob terms of the following value variables could not be derived: {b_val}",
1184+
):
1185+
conditional_logp({b: b_val, a: a_val})

0 commit comments

Comments
 (0)