From 44ffbd5812e9685cc95efd0f1bc9066c29b0d98a Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Sun, 27 Jul 2025 15:45:29 -0500 Subject: [PATCH 1/5] Fix for #7762: add guard and test --- pymc/logprob/mixture.py | 7 ++++--- tests/logprob/test_mixture.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index ce6a11d20..1680f256d 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -62,7 +62,7 @@ is_basic_idx, ) from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceConstant, SliceType +from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import ( @@ -289,9 +289,10 @@ def find_measurable_index_mixture(fgraph, node): # We don't support (non-scalar) integer array indexing as it can pick repeated values, # but the Mixture logprob assumes all mixture values are independent if any( - indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0 + hasattr(indices, "dtype") + and indices.dtype.startswith("int") + and sum(1 - b for b in indices.type.broadcastable) > 0 for indices in mixing_indices - if not isinstance(indices, SliceConstant) ): return None diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index ffb2bf07c..c3afca492 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1156,3 +1156,31 @@ def test_nested_ifelse(): np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1)) np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1)) np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1)) + + +def test_advanced_subtensor_none_and_integer(): + """ + Test for correct error handling when the logp graph is over-specified. + + Providing values for both a random variable ('a') and its deterministic + child ('b') creates a logical conflict. The system should detect this + and raise a controlled RuntimeError. + + This test fails if the rewriter instead crashes with the old internal + AttributeError bug, which would indicate a regression. Please see: #7762 + """ + a = pt.random.normal(0, 1, size=(10,), name="a") + inds = np.array([0, 1, 2, 3], dtype="int32") + b = a[None, inds] + + b_val = b.type() + b_val.name = "b_val" + a_val = a.type() + a_val.name = "a_val" + + with pytest.raises(RuntimeError) as e: + conditional_logp({b: b_val, a: a_val}) + + # Assert that the error message does NOT contain "AttributeError", + # which would indicate the presence of the original bug. + assert "AttributeError" not in str(e.value) From bbac70e10440aa27fa8e73e4936e50c224081e4a Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Mon, 28 Jul 2025 11:58:09 -0500 Subject: [PATCH 2/5] Fix for #7762: improve readability of broadcastable condition Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/logprob/mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 1680f256d..2d56c9765 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -291,7 +291,7 @@ def find_measurable_index_mixture(fgraph, node): if any( hasattr(indices, "dtype") and indices.dtype.startswith("int") - and sum(1 - b for b in indices.type.broadcastable) > 0 + and not all(indices.type.broadcastable) for indices in mixing_indices ): return None From a70cfee03c693c6ab6651cbdb85e9a1e405f0ccf Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Mon, 28 Jul 2025 12:46:24 -0500 Subject: [PATCH 3/5] Fix for #7762: make guard explicit --- pymc/logprob/mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 2d56c9765..62f2a8f40 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -289,7 +289,7 @@ def find_measurable_index_mixture(fgraph, node): # We don't support (non-scalar) integer array indexing as it can pick repeated values, # but the Mixture logprob assumes all mixture values are independent if any( - hasattr(indices, "dtype") + isinstance(indices, TensorVariable) and indices.dtype.startswith("int") and not all(indices.type.broadcastable) for indices in mixing_indices From 2102cb926e5a3dd3db3f6e4b7a19612cc5b55046 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Mon, 28 Jul 2025 18:40:40 -0500 Subject: [PATCH 4/5] Fix for #7762: match on the error message --- tests/logprob/test_mixture.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index c3afca492..d3dc8b403 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1178,9 +1178,5 @@ def test_advanced_subtensor_none_and_integer(): a_val = a.type() a_val.name = "a_val" - with pytest.raises(RuntimeError) as e: + with pytest.raises(RuntimeError, match="logprob terms of the following value variables could not be derived: {b_val}"): conditional_logp({b: b_val, a: a_val}) - - # Assert that the error message does NOT contain "AttributeError", - # which would indicate the presence of the original bug. - assert "AttributeError" not in str(e.value) From d25ce812f459f9ea1d146bc8570ab56bc6537104 Mon Sep 17 00:00:00 2001 From: Asif Zubair Date: Mon, 28 Jul 2025 18:55:18 -0500 Subject: [PATCH 5/5] Fix for #7762: lintr issues --- tests/logprob/test_mixture.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index d3dc8b403..eb9fc8148 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1178,5 +1178,8 @@ def test_advanced_subtensor_none_and_integer(): a_val = a.type() a_val.name = "a_val" - with pytest.raises(RuntimeError, match="logprob terms of the following value variables could not be derived: {b_val}"): + with pytest.raises( + RuntimeError, + match="logprob terms of the following value variables could not be derived: {b_val}", + ): conditional_logp({b: b_val, a: a_val})