Skip to content

Commit 50c184b

Browse files
committed
First draft based on workshop demo
1 parent e6d3390 commit 50c184b

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

pymc/sampling/forward.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"compile_forward_sampling_function",
6868
"draw",
6969
"sample_posterior_predictive",
70+
"sample_prior",
7071
"sample_prior_predictive",
7172
)
7273

@@ -984,3 +985,91 @@ def sample_posterior_predictive(
984985
idata.extend(idata_pp)
985986
return idata
986987
return idata_pp
988+
989+
990+
def sample_prior(
991+
draws: int = 500,
992+
model: Model | None = None,
993+
var_names: Iterable[str] | None = None,
994+
random_seed: RandomState = None,
995+
return_inferencedata: bool = True,
996+
idata_kwargs: dict | None = None,
997+
compile_kwargs: dict | None = None,
998+
) -> InferenceData | dict[str, np.ndarray]:
999+
"""Generate samples from the prior distribution.
1000+
1001+
This function samples only from the prior (unobserved random variables)
1002+
and deterministics that do not depend on observed variables.
1003+
1004+
This is different from `sample_prior_predictive` which samples from both
1005+
prior and prior predictive distributions.
1006+
1007+
Parameters
1008+
----------
1009+
draws : int
1010+
Number of samples from the prior to generate. Defaults to 500.
1011+
model : Model (optional if in ``with`` context)
1012+
var_names : Iterable[str]
1013+
A list of names of variables for which to compute the prior samples.
1014+
random_seed : int, RandomState or Generator, optional
1015+
Seed for the random number generator.
1016+
return_inferencedata : bool
1017+
Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
1018+
Defaults to True.
1019+
idata_kwargs : dict, optional
1020+
Keyword arguments for :func:`pymc.to_inference_data`
1021+
compile_kwargs: dict, optional
1022+
Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
1023+
1024+
Returns
1025+
-------
1026+
arviz.InferenceData or Dict
1027+
An ArviZ ``InferenceData`` object containing the prior samples (default),
1028+
or a dictionary with variable names as keys and samples as numpy arrays.
1029+
1030+
Examples
1031+
--------
1032+
Basic usage:
1033+
1034+
.. code:: python
1035+
1036+
import pymc as pm
1037+
1038+
with pm.Model() as model:
1039+
mu = pm.Normal("mu", 0, 1)
1040+
sigma = pm.HalfNormal("sigma", 1)
1041+
y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
1042+
1043+
# Sample only from the prior (mu and sigma)
1044+
prior_samples = pm.sample_prior(draws=1000)
1045+
1046+
Specify specific variables:
1047+
1048+
.. code:: python
1049+
1050+
with model:
1051+
# Sample only mu from the prior
1052+
mu_samples = pm.sample_prior(draws=1000, var_names=["mu"])
1053+
"""
1054+
model = modelcontext(model)
1055+
1056+
if var_names is None:
1057+
# Default to unobserved random variables
1058+
var_names = (var.name for var in model.unobserved_RVs)
1059+
1060+
# Filter out deterministics that depend on observed variables
1061+
dependent_dets = observed_dependent_deterministics(model)
1062+
var_names = (var_name for var_name in var_names if model[var_name] not in dependent_dets)
1063+
1064+
# Use sample_prior_predictive with filtered var_names
1065+
result = sample_prior_predictive(
1066+
draws=draws,
1067+
model=model,
1068+
var_names=var_names,
1069+
random_seed=random_seed,
1070+
return_inferencedata=return_inferencedata,
1071+
idata_kwargs=idata_kwargs,
1072+
compile_kwargs=compile_kwargs,
1073+
)
1074+
1075+
return result

tests/sampling/test_forward.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,3 +1801,110 @@ def test_sample_prior_predictive_samples_deprecated_warns() -> None:
18011801
match = "The samples argument has been deprecated"
18021802
with pytest.warns(DeprecationWarning, match=match):
18031803
pm.sample_prior_predictive(model=m, samples=10)
1804+
1805+
1806+
class TestSamplePrior:
1807+
def test_basic_prior_sampling(self, seeded_test):
1808+
"""Test that sample_prior only samples from unobserved random variables."""
1809+
with pm.Model() as model:
1810+
mu = pm.Normal("mu", 0, 1)
1811+
sigma = pm.HalfNormal("sigma", 1)
1812+
y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
1813+
det = pm.Deterministic("det", mu + sigma)
1814+
1815+
prior_samples = pm.sample_prior(draws=100, return_inferencedata=False)
1816+
1817+
# Should contain unobserved RVs and deterministics that do not
1818+
# depend on observed variables, but not observed variables
1819+
assert "mu" in prior_samples
1820+
assert "sigma" in prior_samples
1821+
assert "det" in prior_samples # deterministic is included
1822+
assert "y" not in prior_samples # observed variable
1823+
1824+
assert prior_samples["mu"].shape == (100,)
1825+
assert prior_samples["sigma"].shape == (100,)
1826+
assert prior_samples["det"].shape == (100,)
1827+
1828+
def test_specific_var_names(self, seeded_test):
1829+
"""Test sampling specific variables from the prior."""
1830+
with pm.Model() as model:
1831+
mu = pm.Normal("mu", 0, 1)
1832+
sigma = pm.HalfNormal("sigma", 1)
1833+
y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
1834+
1835+
# Sample only mu
1836+
mu_samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)
1837+
1838+
assert "mu" in mu_samples
1839+
assert "sigma" not in mu_samples
1840+
assert "y" not in mu_samples
1841+
assert mu_samples["mu"].shape == (100,)
1842+
1843+
def test_multivariate_prior(self, seeded_test):
1844+
"""Test sampling from multivariate priors."""
1845+
with pm.Model() as model:
1846+
mu = pm.Normal("mu", 0, 1, size=3)
1847+
# Use a simpler multivariate setup to avoid LKJCholeskyCov issues
1848+
mv = pm.MvNormal("mv", mu, cov=np.eye(3), size=4)
1849+
1850+
prior_samples = pm.sample_prior(draws=50, return_inferencedata=False)
1851+
1852+
assert "mu" in prior_samples
1853+
assert "mv" in prior_samples
1854+
assert prior_samples["mu"].shape == (50, 3)
1855+
assert prior_samples["mv"].shape == (50, 4, 3)
1856+
1857+
def test_only_requested_variables(self, seeded_test):
1858+
"""Test that sample_prior only returns the requested variables."""
1859+
with pm.Model() as model:
1860+
mu = pm.Normal("mu", 0, 1)
1861+
sigma = pm.HalfNormal("sigma", 1)
1862+
det = pm.Deterministic("det", mu + sigma)
1863+
y = pm.Normal("y", det, sigma, observed=[1, 2, 3])
1864+
1865+
# Request only mu, but y depends on det which depends on mu
1866+
prior_samples = pm.sample_prior(draws=100, var_names=["mu"], return_inferencedata=False)
1867+
1868+
# Should only contain mu, not det or sigma even though they're dependencies
1869+
assert "mu" in prior_samples
1870+
assert "sigma" not in prior_samples
1871+
assert "det" not in prior_samples
1872+
assert "y" not in prior_samples
1873+
assert len(prior_samples) == 1
1874+
1875+
def test_deterministics_behavior(self, seeded_test):
1876+
"""Test that sample_prior only includes deterministics that don't depend on observed variables."""
1877+
with pm.Model() as model:
1878+
mu = pm.Normal("mu", 0, 1)
1879+
sigma = pm.HalfNormal("sigma", 1)
1880+
y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
1881+
1882+
# Deterministic that depends only on unobserved RVs
1883+
det_prior = pm.Deterministic("det_prior", mu + sigma)
1884+
1885+
# Deterministic that depends on observed RV
1886+
det_obs = pm.Deterministic("det_obs", y + mu)
1887+
1888+
prior_samples = pm.sample_prior(draws=100, return_inferencedata=False)
1889+
1890+
# Should include deterministics that depend only on unobserved RVs
1891+
assert "det_prior" in prior_samples
1892+
1893+
# Should NOT include deterministics that depend on observed variables
1894+
assert "det_obs" not in prior_samples
1895+
1896+
# Should not include observed variables
1897+
assert "y" not in prior_samples
1898+
1899+
def test_empty_var_names_behavior(self, seeded_test):
1900+
"""Test what happens when we pass an empty var_names set."""
1901+
with pm.Model() as model:
1902+
mu = pm.Normal("mu", 0, 1)
1903+
sigma = pm.HalfNormal("sigma", 1)
1904+
y = pm.Normal("y", mu, sigma, observed=[1, 2, 3])
1905+
det = pm.Deterministic("det", mu + sigma)
1906+
1907+
# Test with empty var_names
1908+
empty_samples = pm.sample_prior(draws=100, var_names=[], return_inferencedata=False)
1909+
1910+
assert empty_samples == {}

0 commit comments

Comments
 (0)