65
65
from pymc .logprob .rewriting import cleanup_ir , construct_ir_fgraph
66
66
from pymc .logprob .transform_value import TransformValuesRewrite
67
67
from pymc .logprob .transforms import Transform
68
- from pymc .logprob .utils import find_rvs_in_graph , rvs_to_value_vars
68
+ from pymc .logprob .utils import rvs_in_graph
69
+ from pymc .pytensorf import replace_vars_in_graphs
69
70
70
71
TensorLike : TypeAlias = Union [Variable , float , np .ndarray ]
71
72
@@ -76,7 +77,7 @@ def _find_unallowed_rvs_in_graph(graph):
76
77
77
78
return {
78
79
rv
79
- for rv in find_rvs_in_graph (graph )
80
+ for rv in rvs_in_graph (graph )
80
81
if not isinstance (rv .owner .op , (SimulatorRV , MinibatchIndexRV ))
81
82
}
82
83
@@ -530,11 +531,9 @@ def conditional_logp(
530
531
continue
531
532
532
533
# Replace `RandomVariable`s in the inputs with value variables.
533
- # Also, store the results in the `replacements` map for the nodes
534
- # that follow.
535
- remapped_vars , _ = rvs_to_value_vars (
536
- q_values + list (node .inputs ),
537
- initial_replacements = replacements ,
534
+ remapped_vars = replace_vars_in_graphs (
535
+ graphs = q_values + list (node .inputs ),
536
+ replacements = replacements ,
538
537
)
539
538
q_values = remapped_vars [: len (q_values )]
540
539
q_rv_inputs = remapped_vars [len (q_values ) :]
@@ -562,8 +561,7 @@ def conditional_logp(
562
561
563
562
logprob_vars [q_value_var ] = q_logprob_var
564
563
565
- # Recompute test values for the changes introduced by the
566
- # replacements above.
564
+ # Recompute test values for the changes introduced by the replacements above.
567
565
if config .compute_test_value != "off" :
568
566
for node in io_toposort (graph_inputs (q_logprob_vars ), q_logprob_vars ):
569
567
compute_test_value (node )
0 commit comments