File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed
Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 3333from pytensor .scalar import Cast
3434from pytensor .tensor .elemwise import Elemwise
3535from pytensor .tensor .random .basic import IntegersRV
36+ from pytensor .tensor .random .var import RandomGeneratorSharedVariable
3637from pytensor .tensor .type import TensorType
3738from pytensor .tensor .variable import TensorConstant , TensorVariable
3839
@@ -148,6 +149,19 @@ def __str__(self):
148149 return "Minibatch"
149150
150151
152+ def first_inputs (r ):
153+ if not r .owner :
154+ return
155+
156+ first_input = r .owner .inputs [0 ]
157+ yield first_input
158+ yield from first_inputs (first_input )
159+
160+
161+ def has_random_ancestor (r ):
162+ return any (isinstance (i , RandomGeneratorSharedVariable ) for i in first_inputs (r ))
163+
164+
151165def is_valid_observed (v ) -> bool :
152166 if not isinstance (v , Variable ):
153167 # Non-symbolic constant
@@ -165,6 +179,7 @@ def is_valid_observed(v) -> bool:
165179 and isinstance (v .owner .op .scalar_op , Cast )
166180 and is_valid_observed (v .owner .inputs [0 ])
167181 )
182+ or not has_random_ancestor (v )
168183 # Or Minibatch
169184 or (
170185 isinstance (v .owner .op , MinibatchOp )
You can’t perform that action at this time.
0 commit comments