39
39
import pytensor .tensor as pt
40
40
41
41
from pytensor .graph .rewriting .basic import node_rewriter
42
- from pytensor .raise_op import CheckAndRaise , ExceptionType
42
+ from pytensor .raise_op import CheckAndRaise
43
43
from pytensor .tensor .shape import SpecifyShape
44
44
45
45
from pymc .logprob .abstract import MeasurableVariable , _logprob , _logprob_helper
@@ -106,15 +106,18 @@ class MeasurableCheckAndRaise(CheckAndRaise):
106
106
107
107
108
108
@_logprob .register (MeasurableCheckAndRaise )
109
- def logprob_assert (op , values , inner_rv , * assertion , ** kwargs ):
109
+ def logprob_check_and_raise (op , values , inner_rv , * assertions , ** kwargs ):
110
+ from pymc .pytensorf import replace_rvs_by_values
111
+
110
112
(value ,) = values
111
113
# transfer assertion from rv to value
112
- value = op (assertion , value )
114
+ assertions = replace_rvs_by_values (assertions , rvs_to_values = {inner_rv : value })
115
+ value = op (value , * assertions )
113
116
return _logprob_helper (inner_rv , value )
114
117
115
118
116
119
@node_rewriter ([CheckAndRaise ])
117
- def find_measurable_asserts (fgraph , node ) -> Optional [List [MeasurableCheckAndRaise ]]:
120
+ def find_measurable_check_and_raise (fgraph , node ) -> Optional [List [MeasurableCheckAndRaise ]]:
118
121
r"""Finds `AssertOp`\s for which a `logprob` can be computed."""
119
122
120
123
if isinstance (node .op , MeasurableCheckAndRaise ):
@@ -126,24 +129,19 @@ def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRai
126
129
return None # pragma: no cover
127
130
128
131
base_rv , * conds = node .inputs
132
+ if not rv_map_feature .request_measurable ([base_rv ]):
133
+ return None
129
134
130
- if not (
131
- base_rv .owner
132
- and isinstance (base_rv .owner .op , MeasurableVariable )
133
- and base_rv not in rv_map_feature .rv_values
134
- ):
135
- return None # pragma: no cover
136
-
137
- exception_type = ExceptionType ()
138
- new_op = MeasurableCheckAndRaise (exc_type = exception_type )
135
+ op = node .op
136
+ new_op = MeasurableCheckAndRaise (exc_type = op .exc_type , msg = op .msg )
139
137
new_rv = new_op .make_node (base_rv , * conds ).default_output ()
140
138
141
139
return [new_rv ]
142
140
143
141
144
142
measurable_ir_rewrites_db .register (
145
- "find_measurable_asserts " ,
146
- find_measurable_asserts ,
143
+ "find_measurable_check_and_raise " ,
144
+ find_measurable_check_and_raise ,
147
145
"basic" ,
148
146
"assert" ,
149
147
)
0 commit comments