diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 33c8718ea..f49ca75d9 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -133,6 +133,21 @@ def test_unobserved_categorical(self): assert np.all(np.median(trace["mu"], axis=0) == [1, 2]) + def test_parallel_custom(self): + def _logp(value, mu): + return -((value - mu) ** 2) + + def _random(mu, rng=None, size=None): + return rng.normal(loc=mu, scale=1, size=size) + + def _dist(mu, size=None): + return pm.Normal.dist(mu, 1, size=size) + + with pm.Model(): + mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist) + pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2]) + pm.sample_smc(draws=6, cores=2) + def test_marginal_likelihood(self): """ Verifies that the log marginal likelihood function