Skip to content

Commit 7f71397

Browse files
committed
use existing function
1 parent b36e573 commit 7f71397

File tree

1 file changed

+3
-26
lines changed

1 file changed

+3
-26
lines changed

pymc/data.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,13 @@
3030
from pytensor.compile.sharedvalue import SharedVariable
3131
from pytensor.graph.basic import Variable
3232
from pytensor.raise_op import Assert
33-
from pytensor.scalar import Cast
34-
from pytensor.tensor.elemwise import Elemwise
3533
from pytensor.tensor.random.basic import IntegersRV
36-
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
3734
from pytensor.tensor.type import TensorType
3835
from pytensor.tensor.variable import TensorConstant, TensorVariable
3936

4037
import pymc as pm
4138

39+
from pymc.logprob.utils import rvs_in_graph
4240
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
4341
from pymc.vartypes import isgenerator
4442

@@ -149,22 +147,8 @@ def __str__(self):
149147
return "Minibatch"
150148

151149

152-
def first_inputs(r):
153-
if not r.owner:
154-
return
155-
156-
inputs = r.owner.inputs
157-
158-
if not inputs:
159-
return
160-
161-
first_input = inputs[0]
162-
yield first_input
163-
yield from first_inputs(first_input)
164-
165-
166150
def has_random_ancestor(r):
167-
return any(isinstance(i, RandomGeneratorSharedVariable) for i in first_inputs(r))
151+
return rvs_in_graph(r) != set()
168152

169153

170154
def is_valid_observed(v) -> bool:
@@ -177,14 +161,7 @@ def is_valid_observed(v) -> bool:
177161
return True
178162

179163
return (
180-
# The only PyTensor operation we allow on observed data is type casting
181-
# Although we could allow for any graph that does not depend on other RVs
182-
(
183-
isinstance(v.owner.op, Elemwise)
184-
and isinstance(v.owner.op.scalar_op, Cast)
185-
and is_valid_observed(v.owner.inputs[0])
186-
)
187-
or not has_random_ancestor(v)
164+
not has_random_ancestor(v)
188165
# Or Minibatch
189166
or (
190167
isinstance(v.owner.op, MinibatchOp)

0 commit comments

Comments
 (0)