
Commit da6e49d

.wip test
1 parent 2c81057 commit da6e49d

File tree

3 files changed: +116 −79 lines changed


pymc_experimental/model/marginal/marginal_model.py

Lines changed: 0 additions & 1 deletion
@@ -592,7 +592,6 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
         extra_batch_ndim = dependent_rv.type.ndim + dependent_rv_ndim_supp - marginal_ndim
         valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
         if dependent_rv_batch_dims != valid_dependent_batch_dims:
-            # TODO: This message is too specific
             raise NotImplementedError(
                 f"Link between dimensions of marginalized and dependent RVs not supported: {dependent_rv_batch_dims} != {valid_dependent_batch_dims}"
             )
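
The check above rejects graphs where a dependent RV's batch dimensions cannot be matched one-to-one against the marginalized RV's. A minimal sketch of a model that trips it, mirroring the test_mixed_dims_via_nested_marginalization case added below in this commit (assumes pymc, pytensor, and pymc_experimental are importable):

import pymc as pm
import pytensor.tensor as pt

from pymc_experimental.model.marginal.marginal_model import MarginalModel

with MarginalModel() as m:
    x = pm.Bernoulli("x", p=0.7, shape=(3,))
    y = pm.Bernoulli("y", p=0.7, shape=(2,))
    # pt.add.outer crosses the batch dims of x and y, so z's batch dims
    # cannot be aligned one-to-one with either marginalized RV
    z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))

# Raises the NotImplementedError above:
# m.marginalize([x, y])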

tests/model/marginal/test_graph_analysis.py

Lines changed: 15 additions & 1 deletion
@@ -4,7 +4,21 @@
 from pymc.distributions import CustomDist
 from pytensor.tensor.type_other import NoneTypeT
 
-from pymc_experimental.model.marginal.graph_analysis import subgraph_batch_dim_connection
+from pymc_experimental.model.marginal.graph_analysis import (
+    is_conditional_dependent,
+    subgraph_batch_dim_connection,
+)
+
+
+def test_is_conditional_dependent_static_shape():
+    """Test that we don't consider dependencies through "constant" shape Ops"""
+    x1 = pt.matrix("x1", shape=(None, 5))
+    y1 = pt.random.normal(size=pt.shape(x1))
+    assert is_conditional_dependent(y1, x1, [x1, y1])
+
+    x2 = pt.matrix("x2", shape=(9, 5))
+    y2 = pt.random.normal(size=pt.shape(x2))
+    assert not is_conditional_dependent(y2, x2, [x2, y2])
 
 
 class TestSubgraphBatchDimConnection:
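
The relocated test leans on PyTensor returning a constant from shape() when the static shape is fully known, which severs the graph dependency. A small sketch of that effect, assuming a recent PyTensor where fully static shapes constant-fold:

import pytensor.tensor as pt
from pytensor.graph.basic import ancestors

x1 = pt.matrix("x1", shape=(None, 5))  # first dim unknown: needs a real Shape Op
x2 = pt.matrix("x2", shape=(9, 5))     # fully static: shape folds to a constant

y1 = pt.random.normal(size=pt.shape(x1))
y2 = pt.random.normal(size=pt.shape(x2))

print(x1 in set(ancestors([y1])))  # True: y1's size still references x1
print(x2 in set(ancestors([y2])))  # False: the constant (9, 5) cut the link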

tests/model/marginal/test_marginal_model.py

Lines changed: 101 additions & 77 deletions
@@ -17,15 +17,14 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
-from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from pymc_experimental.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
 )
 from tests.utils import equal_computations_up_to_root
 
 
-def test_marginalized_basic():
+def test_basic_marginalized_rv():
     data = [2] * 5
 
     with MarginalModel() as m:
@@ -69,7 +68,8 @@ def test_marginalized_basic():
     )
 
 
-def test_multiple_independent_marginalized_rvs():
+def test_one_to_one_marginalized_rvs():
+    """Test case with multiple, independent marginalized RVs"""
     with MarginalModel() as m:
         sigma = pm.HalfNormal("sigma")
         idx1 = pm.Bernoulli("idx1", p=0.75)
@@ -95,7 +95,7 @@ def test_multiple_independent_marginalized_rvs():
     np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
 
 
-def test_multiple_dependent_marginalized_rvs():
+def test_one_to_many_marginalized_rvs():
     """Test that marginalization works when there is more than one dependent RV"""
     with MarginalModel() as m:
         sigma = pm.HalfNormal("sigma")
@@ -118,7 +118,37 @@ def test_multiple_dependent_marginalized_rvs():
     np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
 
 
-def test_rv_dependent_multiple_marginalized_rvs():
+def test_one_to_many_unaligned_marginalized_rvs():
+    """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned"""
+
+    def build_model(build_batched: bool):
+        with MarginalModel() as m:
+            if build_batched:
+                idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
+            else:
+                idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
+                idx = pt.stack(idxs, axis=0).reshape(3, 2)
+
+            x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
+            y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))
+
+        return m
+
+    m = build_model(build_batched=True)
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        m.marginalize(["idx"])
+
+    ref_m = build_model(build_batched=False)
+    ref_m.marginalize([f"idx_{i}" for i in range(6)])
+
+    test_point = m.initial_point()
+    np.testing.assert_allclose(
+        m.compile_logp()(test_point),
+        ref_m.compile_logp()(test_point),
+    )
+
+
+def test_many_to_one_marginalized_rvs():
     """Test when random variables depend on multiple marginalized variables"""
     with MarginalModel() as m:
         x = pm.Bernoulli("x", 0.1)
@@ -133,13 +163,13 @@ def test_rv_dependent_multiple_marginalized_rvs():
     np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
 
 
-@pytest.mark.parametrize("batched", (False, True))
+@pytest.mark.parametrize("batched", (False, "left", "right"))
 def test_nested_marginalized_rvs(batched):
     """Test that marginalization works when there are nested marginalized RVs"""
 
     def build_model(build_batched: bool) -> MarginalModel:
         idx_shape = (3,) if build_batched else ()
-        sub_idx_shape = (3, 5) if build_batched else (5,)
+        sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
 
         with MarginalModel() as m:
             sigma = pm.HalfNormal("sigma")
@@ -148,9 +178,9 @@ def build_model(build_batched: bool) -> MarginalModel:
             dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
 
             sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
-            if build_batched:
-                sub_idx_p = sub_idx_p[:, None]
-                dep = dep[:, None]
+            if build_batched and batched == "right":
+                sub_idx_p = sub_idx_p[..., None]
+                dep = dep[..., None]
             sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
             sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)
 
@@ -204,22 +234,22 @@ def test_marginalized_index_as_key(advanced_indexing):
 
     with MarginalModel() as m:
         x = pm.Categorical("x", p=w, shape=shape)
-        y = pm.Normal("y", mu[x], sigma=1, observed=y_val)
+        y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
 
     m.marginalize(x)
 
     marginal_logp = m.compile_logp(sum=False)({})[0]
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu, sigma=1, shape=shape), y_val).eval()
+    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
 
     np.testing.assert_allclose(marginal_logp, ref_logp)
 
 
 def test_marginalized_index_as_value_and_key():
     """Test we can marginalize graphs were marginalized_rv is indexed."""
 
-    def build_model(batch: bool) -> MarginalModel:
+    def build_model(build_batched: bool) -> MarginalModel:
         with MarginalModel() as m:
-            if batch:
+            if build_batched:
                 latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
             else:
                 latent_state = pm.math.stack(
@@ -237,8 +267,8 @@ def build_model(batch: bool) -> MarginalModel:
         return m
 
     # We compare with the equivalent but less efficient batched model
-    m = build_model(batch=True)
-    ref_m = build_model(batch=False)
+    m = build_model(build_batched=True)
+    ref_m = build_model(build_batched=False)
 
     m.marginalize(["latent_state"])
     ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
@@ -317,6 +347,14 @@ def test_mixed_dims_via_support_dimension(self):
         with pytest.raises(NotImplementedError):
             m.marginalize(x)
 
+    def test_mixed_dims_via_nested_marginalization(self):
+        with MarginalModel() as m:
+            x = pm.Bernoulli("x", p=0.7, shape=(3,))
+            y = pm.Bernoulli("y", p=0.7, shape=(2,))
+            z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
+        with pytest.raises(NotImplementedError):
+            m.marginalize([x, y])
+
 
 def test_marginalized_deterministic_and_potential():
     rng = np.random.default_rng(299)
@@ -432,17 +470,6 @@ def test_marginalized_transforms(transform, expected_warning):
     np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
 
 
-def test_is_conditional_dependent_static_shape():
-    """Test that we don't consider dependencies through "constant" shape Ops"""
-    x1 = pt.matrix("x1", shape=(None, 5))
-    y1 = pt.random.normal(size=pt.shape(x1))
-    assert is_conditional_dependent(y1, x1, [x1, y1])
-
-    x2 = pt.matrix("x2", shape=(9, 5))
-    y2 = pt.random.normal(size=pt.shape(x2))
-    assert not is_conditional_dependent(y2, x2, [x2, y2])
-
-
 def test_data_container():
     """Test that MarginalModel can handle Data containers."""
     with MarginalModel(coords={"obs": [0]}) as marginal_m:
@@ -469,49 +496,6 @@ def test_data_container():
     np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
 
 
-@pytest.mark.parametrize("univariate", (True, False))
-def test_vector_univariate_mixture(univariate):
-    with MarginalModel() as m:
-        idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
-
-        def dist(idx, size):
-            return pm.math.switch(
-                pm.math.eq(idx, 0),
-                pm.Normal.dist([-10, -10], 1),
-                pm.Normal.dist([10, 10], 1),
-            )
-
-        pm.CustomDist("norm", idx, dist=dist)
-
-    m.marginalize(idx)
-    logp_fn = m.compile_logp()
-
-    if univariate:
-        with pm.Model() as ref_m:
-            pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
-    else:
-        with pm.Model() as ref_m:
-            pm.Mixture(
-                "norm",
-                w=[0.5, 0.5],
-                comp_dists=[
-                    pm.MvNormal.dist([-10, -10], np.eye(2)),
-                    pm.MvNormal.dist([10, 10], np.eye(2)),
-                ],
-                shape=(2,),
-            )
-    ref_logp_fn = ref_m.compile_logp()
-
-    for test_value in (
-        [-10, -10],
-        [10, 10],
-        [-10, 10],
-        [-10, 10],
-    ):
-        pt = {"norm": test_value}
-        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
-
-
 def test_mutable_indexing_jax_backend():
     pytest.importorskip("jax")
     from pymc.sampling.jax import get_jaxified_logp
@@ -631,11 +615,51 @@ def test_change_point_model_sampling(self, disaster_model):
             rtol=1e-2,
         )
 
-    @pytest.mark.parametrize(
-        "batch_right", (True, pytest.param(False, marks=pytest.mark.xfail(reason="NotImplemented")))
-    )
+    @pytest.mark.parametrize("univariate", (True, False))
+    def test_vector_univariate_mixture(self, univariate):
+        with MarginalModel() as m:
+            idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
+
+            def dist(idx, size):
+                return pm.math.switch(
+                    pm.math.eq(idx, 0),
+                    pm.Normal.dist([-10, -10], 1),
+                    pm.Normal.dist([10, 10], 1),
+                )
+
+            pm.CustomDist("norm", idx, dist=dist)
+
+        m.marginalize(idx)
+        logp_fn = m.compile_logp()
+
+        if univariate:
+            with pm.Model() as ref_m:
+                pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
+        else:
+            with pm.Model() as ref_m:
+                pm.Mixture(
+                    "norm",
+                    w=[0.5, 0.5],
+                    comp_dists=[
+                        pm.MvNormal.dist([-10, -10], np.eye(2)),
+                        pm.MvNormal.dist([10, 10], np.eye(2)),
+                    ],
+                    shape=(2,),
+                )
+        ref_logp_fn = ref_m.compile_logp()
+
+        for test_value in (
+            [-10, -10],
+            [10, 10],
+            [-10, 10],
+            [-10, 10],
+        ):
+            pt = {"norm": test_value}
+            np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+    @pytest.mark.parametrize("batch_right", (True, False))
     def test_k_censored_clusters_model(self, batch_right):
-        def build_model(batch: bool) -> MarginalModel:
+        def build_model(build_batched: bool) -> MarginalModel:
             data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
             nobs = data.shape[0]
             n_clusters = 5
@@ -645,7 +669,7 @@ def build_model(batch: bool) -> MarginalModel:
                 "obs": range(nobs),
             }
             with MarginalModel(coords=coords) as m:
-                if batch:
+                if build_batched:
                     idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
                 else:
                     idx = pm.math.stack(
@@ -682,8 +706,8 @@ def build_model(batch: bool) -> MarginalModel:
 
             return m
 
-        m = build_model(batch=True)
-        ref_m = build_model(batch=False)
+        m = build_model(build_batched=True)
+        ref_m = build_model(build_batched=False)
 
         m.marginalize([m["idx"]])
         ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
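
All of the renamed tests in this file compare a marginalized model against an explicit reference mixture. A minimal usage sketch of that pattern, modeled on test_basic_marginalized_rv; the Bernoulli/Normal setup here is illustrative, not taken from the commit:

import numpy as np
import pymc as pm

from pymc_experimental.model.marginal.marginal_model import MarginalModel

# A mixture written with an explicit discrete indicator...
with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.75)
    y = pm.Normal("y", mu=idx * 2 - 1, sigma=1.0)

m.marginalize([idx])  # idx is summed out of the logp

# ...matches the closed-form two-component mixture
with pm.Model() as ref_m:
    pm.NormalMixture("y", w=[0.25, 0.75], mu=[-1, 1], sigma=1.0)

ip = m.initial_point()
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))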
