1818import pytensor
1919import pytensor .tensor as pt
2020from pymc .logprob .utils import get_underlying_scalar_constant_value
21+ from pymc .pytensorf import find_rng_nodes
2122from pytensor .graph import FunctionGraph
2223from pytensor .graph .basic import clone_get_equiv
2324from pytensor .tensor .random .basic import NormalRV
3031 from pytensor .graph .opt import local_optimizer as node_rewriter
3132
3233
33- def find_rng_nodes (outputs ):
34- """
35- Find all RNG (random number generator) shared variables in the graph.
36-
37- In PyMC v5/PyTensor, this replaces pymc.pytensorf.find_rng_nodes from PyMC v4.
38- """
39- from pytensor .graph import graph_inputs
40- from pytensor .tensor .random .var import (
41- RandomGeneratorSharedVariable ,
42- RandomStateSharedVariable ,
43- )
44-
45- # Get all inputs to the graph
46- inputs = graph_inputs (outputs )
47-
48- # Filter for RNG shared variables
49- rng_nodes = [
50- node
51- for node in inputs
52- if isinstance (node , (RandomStateSharedVariable , RandomGeneratorSharedVariable ))
53- ]
54-
55- return rng_nodes
56-
57-
5834def resampled_as_non_centered (outputs , resampled_vars , free_RVs ):
5935 # We make a shallow copy of the lists so that we don't make alter the original ones
6036 resampled_vars = [[rv for rv in rvs ] for rvs in resampled_vars ]
@@ -95,7 +71,6 @@ def make_normal_not_centered(fgraph, node):
9571 return [None , new_node ]
9672
9773 rewrite = in2out (make_normal_not_centered )
98-
9974 graph = FunctionGraph (outputs = outputs , clone = False )
10075 rewrite .apply (graph )
10176 return graph .outputs , free_RVs , resampled_vars
@@ -114,15 +89,9 @@ def clone_replace_rv_consistent(outputs, free_RVs, replace):
11489 # That way, the draws across the cloned and uncloned graph will be uncorrelated
11590 rng_nodes = find_rng_nodes (fg .outputs )
11691 new_rng_nodes : List [Union [np .random .RandomState , np .random .Generator ]] = []
117- from pytensor .tensor .random .var import RandomStateSharedVariable
11892
119- for rng_node in rng_nodes :
120- rng_cls : type
121- if isinstance (rng_node , RandomStateSharedVariable ):
122- rng_cls = np .random .RandomState
123- else :
124- rng_cls = np .random .Generator
125- new_rng_nodes .append (pytensor .shared (rng_cls (np .random .PCG64 ())))
93+ for _ in rng_nodes :
94+ new_rng_nodes .append (pytensor .shared (np .random .Generator (np .random .PCG64 ())))
12695 orig_replace = {clone_map [rv ]: rv for rv in free_RVs if rv in clone_map }
12796 orig_replace .update (dict (zip (rng_nodes , new_rng_nodes )))
12897 # replace_var can only be constant values or shared, not graph that depend on nodes
0 commit comments