Skip to content

Commit 9bba026

Browse files
ricardoV94twiecki
authored andcommitted
Fix logprob of check_and_raise
1 parent 7ed5b71 commit 9bba026

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ jobs:
105105
tests/logprob/test_abstract.py
106106
tests/logprob/test_basic.py
107107
tests/logprob/test_binary.py
108+
tests/logprob/test_checks.py
108109
tests/logprob/test_censoring.py
109110
tests/logprob/test_composite_logprob.py
110111
tests/logprob/test_cumsum.py

pymc/logprob/checks.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import pytensor.tensor as pt
4040

4141
from pytensor.graph.rewriting.basic import node_rewriter
42-
from pytensor.raise_op import CheckAndRaise, ExceptionType
42+
from pytensor.raise_op import CheckAndRaise
4343
from pytensor.tensor.shape import SpecifyShape
4444

4545
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
@@ -106,15 +106,18 @@ class MeasurableCheckAndRaise(CheckAndRaise):
106106

107107

108108
@_logprob.register(MeasurableCheckAndRaise)
109-
def logprob_assert(op, values, inner_rv, *assertion, **kwargs):
109+
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
110+
from pymc.pytensorf import replace_rvs_by_values
111+
110112
(value,) = values
111113
# transfer assertion from rv to value
112-
value = op(assertion, value)
114+
assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value})
115+
value = op(value, *assertions)
113116
return _logprob_helper(inner_rv, value)
114117

115118

116119
@node_rewriter([CheckAndRaise])
117-
def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
120+
def find_measurable_check_and_raise(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
118121
r"""Finds `AssertOp`\s for which a `logprob` can be computed."""
119122

120123
if isinstance(node.op, MeasurableCheckAndRaise):
@@ -126,24 +129,19 @@ def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRai
126129
return None # pragma: no cover
127130

128131
base_rv, *conds = node.inputs
132+
if not rv_map_feature.request_measurable([base_rv]):
133+
return None
129134

130-
if not (
131-
base_rv.owner
132-
and isinstance(base_rv.owner.op, MeasurableVariable)
133-
and base_rv not in rv_map_feature.rv_values
134-
):
135-
return None # pragma: no cover
136-
137-
exception_type = ExceptionType()
138-
new_op = MeasurableCheckAndRaise(exc_type=exception_type)
135+
op = node.op
136+
new_op = MeasurableCheckAndRaise(exc_type=op.exc_type, msg=op.msg)
139137
new_rv = new_op.make_node(base_rv, *conds).default_output()
140138

141139
return [new_rv]
142140

143141

144142
measurable_ir_rewrites_db.register(
145-
"find_measurable_asserts",
146-
find_measurable_asserts,
143+
"find_measurable_check_and_raise",
144+
find_measurable_check_and_raise,
147145
"basic",
148146
"assert",
149147
)

tests/logprob/test_checks.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_assert_logprob():
8181
rv = pt.random.normal()
8282
assert_op = Assert("Test assert")
8383
# Example: Add assert that rv must be positive
84-
assert_rv = assert_op(rv > 0, rv)
84+
assert_rv = assert_op(rv, rv > 0)
8585
assert_rv.name = "assert_rv"
8686

8787
assert_vv = assert_rv.clone()
@@ -90,8 +90,10 @@ def test_assert_logprob():
9090
# Check valid value is correct and doesn't raise
9191
# Since here the value to the rv satisfies the condition, no error is raised.
9292
valid_value = 3.0
93-
with pytest.raises(AssertionError, match="Test assert"):
94-
assert_logp.eval({assert_vv: valid_value})
93+
np.testing.assert_allclose(
94+
assert_logp.eval({assert_vv: valid_value}),
95+
stats.norm.logpdf(valid_value),
96+
)
9597

9698
# Check invalid value
9799
# Since here the value to the rv is negative, an exception is raised as the condition is not met

0 commit comments

Comments
 (0)