Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
from pymc.vartypes import isgenerator

Expand Down Expand Up @@ -148,6 +147,10 @@ def __str__(self):
return "Minibatch"


def has_random_ancestor(r):
return rvs_in_graph(r) != set()


def is_valid_observed(v) -> bool:
if not isinstance(v, Variable):
# Non-symbolic constant
Expand All @@ -158,13 +161,7 @@ def is_valid_observed(v) -> bool:
return True

return (
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
(
isinstance(v.owner.op, Elemwise)
and isinstance(v.owner.op.scalar_op, Cast)
and is_valid_observed(v.owner.inputs[0])
)
not has_random_ancestor(v)
# Or Minibatch
or (
isinstance(v.owner.op, MinibatchOp)
Expand Down
5 changes: 5 additions & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)

from pymc.data import has_random_ancestor

if not has_random_ancestor(x):
return x.eval()

raise TypeError(f"Data cannot be extracted from {x}")


Expand Down
11 changes: 11 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,14 @@ def test_multiple_vars(self):
[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, -draw_mB)


def test_scaling_data_works_in_likelihood() -> None:
data = np.array([10, 11, 12, 13, 14, 15])

with pm.Model() as model:
target = pm.Data("target", data)
scale = 12
scaled_target = target / scale
mu = pm.Normal("mu", mu=0, sigma=1)
pm.Normal("x", mu=mu, sigma=1, observed=scaled_target)
Loading