40
40
from typing import Dict , List , Optional , Sequence , Union
41
41
42
42
import numpy as np
43
- import pytensor
44
43
import pytensor .tensor as pt
45
44
46
45
from pytensor import config
53
52
)
54
53
from pytensor .graph .op import compute_test_value
55
54
from pytensor .graph .rewriting .basic import GraphRewriter , NodeRewriter
56
- from pytensor .tensor .random .op import RandomVariable
57
55
from pytensor .tensor .var import TensorVariable
58
56
from typing_extensions import TypeAlias
59
57
66
64
)
67
65
from pymc .logprob .rewriting import cleanup_ir , construct_ir_fgraph
68
66
from pymc .logprob .transforms import RVTransform , TransformValuesRewrite
69
- from pymc .logprob .utils import rvs_to_value_vars
67
+ from pymc .logprob .utils import find_rvs_in_graph , rvs_to_value_vars
70
68
71
69
TensorLike : TypeAlias = Union [Variable , float , np .ndarray ]
72
70
73
71
72
+ def _find_unallowed_rvs_in_graph (graph ):
73
+ from pymc .data import MinibatchIndexRV
74
+ from pymc .distributions .simulator import SimulatorRV
75
+
76
+ return {
77
+ rv
78
+ for rv in find_rvs_in_graph (graph )
79
+ if not isinstance (rv .owner .op , (SimulatorRV , MinibatchIndexRV ))
80
+ }
81
+
82
+
74
83
def _warn_rvs_in_inferred_graph (graph : Union [TensorVariable , Sequence [TensorVariable ]]):
75
84
"""Issue warning if any RVs are found in graph.
76
85
@@ -81,13 +90,11 @@ def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVari
81
90
This makes it impossible (or difficult) to replace it by the respective values afterward,
82
91
so we instruct users to do it beforehand.
83
92
"""
84
- from pymc .testing import assert_no_rvs
85
93
86
- try :
87
- assert_no_rvs (graph )
88
- except AssertionError :
94
+ rvs_in_graph = _find_unallowed_rvs_in_graph (graph )
95
+ if rvs_in_graph :
89
96
warnings .warn (
90
- "RandomVariables were found in the derived graph. "
97
+ f "RandomVariables { rvs_in_graph } were found in the derived graph. "
91
98
"These variables are a clone and do not match the original ones on identity.\n "
92
99
"If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
93
100
"`logp(model.replace_rvs_by_values([rv])[0], value)`" ,
@@ -149,6 +156,13 @@ def icdf(
149
156
return expr
150
157
151
158
159
+ RVS_IN_JOINT_LOGP_GRAPH_MSG = (
160
+ "Random variables detected in the logp graph: %s.\n "
161
+ "This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n "
162
+ "or when not all rvs have a corresponding value variable."
163
+ )
164
+
165
+
152
166
def factorized_joint_logprob (
153
167
rv_values : Dict [TensorVariable , TensorVariable ],
154
168
warn_missing_rvs : bool = True ,
@@ -316,34 +330,13 @@ def factorized_joint_logprob(
316
330
cleanup_ir (logprob_expressions )
317
331
318
332
if warn_missing_rvs :
319
- _warn_rvs_in_inferred_graph (logprob_expressions )
333
+ rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logprob_expressions )
334
+ if rvs_in_logp_expressions :
335
+ warnings .warn (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions , UserWarning )
320
336
321
337
return logprob_vars
322
338
323
339
324
- def _check_no_rvs (logp_terms : Sequence [TensorVariable ]):
325
- # Raise if there are unexpected RandomVariables in the logp graph
326
- # Only SimulatorRVs MinibatchIndexRVs are allowed
327
- from pymc .data import MinibatchIndexRV
328
- from pymc .distributions .simulator import SimulatorRV
329
-
330
- unexpected_rv_nodes = [
331
- node
332
- for node in pytensor .graph .ancestors (logp_terms )
333
- if (
334
- node .owner
335
- and isinstance (node .owner .op , RandomVariable )
336
- and not isinstance (node .owner .op , (SimulatorRV , MinibatchIndexRV ))
337
- )
338
- ]
339
- if unexpected_rv_nodes :
340
- raise ValueError (
341
- f"Random variables detected in the logp graph: { unexpected_rv_nodes } .\n "
342
- "This can happen when DensityDist logp or Interval transform functions "
343
- "reference nonlocal variables."
344
- )
345
-
346
-
347
340
def joint_logp (
348
341
rvs : Sequence [TensorVariable ],
349
342
* ,
@@ -381,5 +374,10 @@ def joint_logp(
381
374
value_var = rvs_to_values [rv ]
382
375
logp_terms [value_var ] = temp_logp_terms [value_var ]
383
376
384
- _check_no_rvs (list (logp_terms .values ()))
385
- return list (logp_terms .values ())
377
+ logp_terms_list = list (logp_terms .values ())
378
+
379
+ rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logp_terms_list )
380
+ if rvs_in_logp_expressions :
381
+ raise ValueError (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions )
382
+
383
+ return logp_terms_list
0 commit comments