3030from pytensor .compile .sharedvalue import SharedVariable
3131from pytensor .graph .basic import Variable
3232from pytensor .raise_op import Assert
33- from pytensor .scalar import Cast
34- from pytensor .tensor .elemwise import Elemwise
3533from pytensor .tensor .random .basic import IntegersRV
36- from pytensor .tensor .random .var import RandomGeneratorSharedVariable
3734from pytensor .tensor .type import TensorType
3835from pytensor .tensor .variable import TensorConstant , TensorVariable
3936
4037import pymc as pm
4138
39+ from pymc .logprob .utils import rvs_in_graph
4240from pymc .pytensorf import GeneratorOp , convert_data , smarttypeX
4341from pymc .vartypes import isgenerator
4442
@@ -149,22 +147,8 @@ def __str__(self):
149147 return "Minibatch"
150148
151149
152- def first_inputs (r ):
153- if not r .owner :
154- return
155-
156- inputs = r .owner .inputs
157-
158- if not inputs :
159- return
160-
161- first_input = inputs [0 ]
162- yield first_input
163- yield from first_inputs (first_input )
164-
165-
166150def has_random_ancestor (r ):
167- return any ( isinstance ( i , RandomGeneratorSharedVariable ) for i in first_inputs ( r ) )
151+ return rvs_in_graph ( r ) != set ( )
168152
169153
170154def is_valid_observed (v ) -> bool :
@@ -177,14 +161,7 @@ def is_valid_observed(v) -> bool:
177161 return True
178162
179163 return (
180- # The only PyTensor operation we allow on observed data is type casting
181- # Although we could allow for any graph that does not depend on other RVs
182- (
183- isinstance (v .owner .op , Elemwise )
184- and isinstance (v .owner .op .scalar_op , Cast )
185- and is_valid_observed (v .owner .inputs [0 ])
186- )
187- or not has_random_ancestor (v )
164+ not has_random_ancestor (v )
188165 # Or Minibatch
189166 or (
190167 isinstance (v .owner .op , MinibatchOp )
0 commit comments