Skip to content

Commit 56ad68b

Browse files
committed
simplify and remove helper function
1 parent acd22f3 commit 56ad68b

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

pymc/data.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ def __str__(self):
147147
return "Minibatch"
148148

149149

150-
def has_random_ancestor(r):
151-
return rvs_in_graph(r) != set()
152-
153-
154150
def is_valid_observed(v) -> bool:
155151
if not isinstance(v, Variable):
156152
# Non-symbolic constant
@@ -161,7 +157,7 @@ def is_valid_observed(v) -> bool:
161157
return True
162158

163159
return (
164-
not has_random_ancestor(v)
160+
not rvs_in_graph(v)
165161
# Or Minibatch
166162
or (
167163
isinstance(v.owner.op, MinibatchOp)

pymc/pytensorf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
177177
mask[mask_idx] = 1
178178
return np.ma.MaskedArray(array_data, mask)
179179

180-
from pymc.data import has_random_ancestor
180+
from pymc.logprob.utils import rvs_in_graph
181181

182-
if not has_random_ancestor(x):
182+
if not rvs_in_graph(x):
183183
cheap_eval_mode = Mode(linker="py", optimizer=None)
184184
return x.eval(mode=cheap_eval_mode)
185185

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def test_multiple_vars(self):
628628
def test_scaling_data_works_in_likelihood() -> None:
629629
data = np.array([10, 11, 12, 13, 14, 15])
630630

631-
with pm.Model() as model:
631+
with pm.Model():
632632
target = pm.Data("target", data)
633633
scale = 12
634634
scaled_target = target / scale

0 commit comments

Comments
 (0)