Skip to content

Commit 16ce329

Browse files
committed
Fix MarginalModel with Data containers
1 parent ba10a00 commit 16ce329

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ def logp(self, vars=None, **kwargs):
212212
return m._logp(vars=vars, **kwargs)
213213

214214
def clone(self):
215-
m = MarginalModel()
216-
vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
215+
m = MarginalModel(coords=self.coords)
216+
model_vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
217+
data_vars = [var for name, var in self.named_vars.items() if var not in model_vars]
218+
vars = model_vars + data_vars
217219
cloned_vars = clone_replace(vars)
218220
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
219221
m.vars_to_clone = vars_to_clone
@@ -598,7 +600,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
598600
# can ultimately be generated that is proportional to the support domain and not
599601
# to the variables dimensions
600602
# We don't need to worry about this if the RV is scalar.
601-
if np.prod(constant_fold(tuple(rv_to_marginalize.shape))) > 1:
603+
if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1:
602604
if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs):
603605
raise NotImplementedError(
604606
"The subgraph between a marginalized RV and its dependents includes non Elemwise operations. "
@@ -682,7 +684,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
682684
# batched dimensions of the marginalized RV
683685

684686
# PyMC does not allow RVs in the logp graph, even if we are just using the shape
685-
marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
687+
marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False)
686688
marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
687689
marginalized_rv_domain_tensor = pt.moveaxis(
688690
pt.full(

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,3 +603,29 @@ def test_is_conditional_dependent_static_shape():
603603
x2 = pt.matrix("x2", shape=(9, 5))
604604
y2 = pt.random.normal(size=pt.shape(x2))
605605
assert not is_conditional_dependent(y2, x2, [x2, y2])
606+
607+
608+
def test_data_container():
609+
"""Test that MarginalModel can handle Data containers."""
610+
with MarginalModel(coords_mutable={"obs": [0]}) as marginal_m:
611+
x = pm.MutableData("x", 2.5)
612+
idx = pm.Bernoulli("idx", p=0.7, dims="obs")
613+
y = pm.Normal("y", idx * x, dims="obs")
614+
615+
marginal_m.marginalize([idx])
616+
617+
logp_fn = marginal_m.compile_logp()
618+
619+
with pm.Model(coords_mutable={"obs": [0]}) as m_ref:
620+
x = pm.MutableData("x", 2.5)
621+
y = pm.NormalMixture("y", w=[0.3, 0.7], mu=[0, x], dims="obs")
622+
623+
ref_logp_fn = m_ref.compile_logp()
624+
625+
for i, x_val in enumerate((-1.5, 2.5, 3.5), start=1):
626+
for m in (marginal_m, m_ref):
627+
m.set_dim("obs", new_length=i, coord_values=tuple(range(i)))
628+
pm.set_data({"x": x_val}, model=m)
629+
630+
ip = marginal_m.initial_point()
631+
np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))

0 commit comments

Comments
 (0)