Skip to content

Commit 6b410f3

Browse files
committed
check for variable having inputvars
1 parent 56ad68b commit 6b410f3

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

pymc/pytensorf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
179179

180180
from pymc.logprob.utils import rvs_in_graph
181181

182-
if not rvs_in_graph(x):
182+
if not inputvars(x) and not rvs_in_graph(x):
183183
cheap_eval_mode = Mode(linker="py", optimizer=None)
184184
return x.eval(mode=cheap_eval_mode)
185185

tests/test_pytensorf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ def test_pytensor_operations(self):
207207
assert isinstance(res, np.ndarray)
208208
np.testing.assert_array_equal(res, np.array([4, 7, 10]))
209209

210+
def test_pytensor_operations_raises(self):
211+
x = pt.scalar("x")
212+
target = 1 + 3 * x
213+
214+
with pytest.raises(TypeError, match="Data cannot be extracted from"):
215+
extract_obs_data(target)
216+
210217

211218
@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
212219
def test_convert_data(input_dtype):

0 commit comments

Comments
 (0)