Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@ dependencies:
- pytest>=4.4.0
- pre-commit>=2.19
- ruff==0.9.1

1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ bambi>=0.13.0
arviz_base>=0.5.0
ruff==0.9.1
numpyro>=0.17.0
numba>=0.60.0
74 changes: 69 additions & 5 deletions simuk/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
except ImportError:
pass

from collections.abc import Mapping

import numpy as np
from arviz_base import extract, from_dict, from_numpyro
from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro
from tqdm import tqdm


Expand Down Expand Up @@ -59,6 +61,9 @@ class SBC:
data_dir : dict
Keyword arguments passed to numpyro model, intended for use when providing
an MCMC Kernel model.
simulator : callable
A custom simulator function that takes as input the model parameters and
a int parameter named `seed`, and must return a dictionary of named observations.

Example
-------
Expand All @@ -73,7 +78,15 @@ class SBC:
sbc.run_simulations()
"""

def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None):
def __init__(
self,
model,
num_simulations=1000,
sample_kwargs=None,
seed=None,
data_dir=None,
simulator=None,
):
if hasattr(model, "basic_RVs") and isinstance(model, pm.Model):
self.engine = "pymc"
self.model = model
Expand Down Expand Up @@ -110,6 +123,22 @@ def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, d
self._extract_variable_names()
self.simulations = {name: [] for name in self.var_names}
self._simulations_complete = 0
if simulator is not None and not callable(simulator):
raise ValueError("simulator should be a function or None")
if simulator is not None and self.observed_vars:
logging.warning(
"Provided model contains both observed variables and a simulator. "
"Ignoring observed variables and using the simulator instead."
)
if simulator is None and not self.observed_vars and self.engine == "pymc":
# Ideally, we could raise an error early for `numpyro` also,
# but `factor` also produces 'observed_vars'
raise ValueError(
"There are no observed variables, and PyMC will not generate prior "
"predictive samples. Either change the model or specify a simulator "
"with the `simulator` argument."
)
self.simulator = simulator

def _extract_variable_names(self):
"""Extract observed and free variables from the model."""
Expand Down Expand Up @@ -142,8 +171,30 @@ def _get_prior_predictive_samples(self):
idata = pm.sample_prior_predictive(
samples=self.num_simulations, random_seed=self._seeds[0]
)
prior_pred = extract(idata, group="prior_predictive", keep_dataset=True)
prior = extract(idata, group="prior", keep_dataset=True)
if self.simulator is None:
prior_pred = extract(idata, group="prior_predictive", keep_dataset=True)
return prior, prior_pred
# Deal with custom simulator
prior_pred = []
for i in range(prior.sizes["sample"]):
params = {var: prior[var].isel(sample=i).values for var in prior.data_vars}
params["seed"] = self._seeds[i]
try:
res = self.simulator(**params)
assert isinstance(
res, Mapping
), f"Simulator must return a dictionary, got {type(res)}"
prior_pred.append(res)
except Exception as e:
raise ValueError(
f"Error generating prior predictive sample with parameters {params}: {e}."
)
prior_pred = dict_to_dataset(
{key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]},
sample_dims=["sample"],
coords={**prior.coords},
)
return prior, prior_pred

def _get_prior_predictive_samples_numpyro(self):
Expand All @@ -152,7 +203,15 @@ def _get_prior_predictive_samples_numpyro(self):
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data)
prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
if self.simulator:
results = []
for i, vals in enumerate(zip(*prior.values())):
params = dict(zip(prior.keys(), vals))
params["seed"] = self._seeds[i]
results.append(self.simulator(**params))
prior_pred = {key: [result[key] for result in results] for key in results[0]}
else:
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
return prior, prior_pred

def _get_posterior_samples(self, prior_predictive_draw):
Expand All @@ -170,7 +229,12 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw):
"""Generate posterior samples using numpyro conditioned to a prior predictive sample."""
mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete])
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
# If using a custom simulator, some variables present in `prior_predictive_draw`
# might be missing from self.observed_vars.
# TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys()
# could be sufficient.
extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars)
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars}
mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw)
return from_numpyro(mcmc)["posterior"]

Expand Down
146 changes: 134 additions & 12 deletions simuk/tests/test_sbc.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,88 @@
import bambi as bmb
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import pymc as pm
import pytest
from numba import njit
from numpyro.infer import NUTS

import simuk

np.random.seed(1234)

# Test data
data = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# PyMC models
with pm.Model() as centered_eight:
mu = pm.Normal("mu", mu=0, sigma=5)
tau = pm.HalfCauchy("tau", beta=5)
theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8)
y_obs = pm.Normal("y", mu=theta, sigma=sigma, observed=data)

with pm.Model() as centered_eight_no_observed:
mu = pm.Normal("mu", mu=0, sigma=5)
tau = pm.HalfCauchy("tau", beta=5)
theta = pm.Normal("theta", mu=mu, sigma=tau, shape=8)

def log_likelihood(theta, observed):
return pm.math.sum(pm.logp(pm.Normal.dist(mu=theta, sigma=sigma), observed))

pm.Potential("y_loglike", log_likelihood(mu, data))

# Bambi model
x = np.random.normal(0, 1, 20)
y = 2 + np.random.normal(x, 1)
df = pd.DataFrame({"x": x, "y": y})
bmb_model = bmb.Model("y ~ x", df)

# NumPyro models
def eight_schools_cauchy_prior(J, sigma, y=None):
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
theta = numpyro.sample("theta", dist.Normal(mu, tau))
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)


def eight_schools_cauchy_prior_no_observed(J, sigma, y=None):
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
theta = numpyro.sample("theta", dist.Normal(mu, tau))
if y is not None:
log_likelihood = jnp.sum(dist.Normal(theta, sigma).log_prob(y))
numpyro.factor("custom_likelihood", log_likelihood)


# Custom simulator functions
def centered_eight_simulator(theta, seed, **kwargs):
rng = np.random.default_rng(seed)
return {"y": rng.normal(theta, sigma)}


@njit
def centered_eight_jitted_simulator(tau, mu, theta, seed):
# Some expensive computation
n = theta.shape[0]
y = np.zeros(n)
for i in range(n):
y[i] = theta[i]
return {"y": y}


def bmb_simulator(mu, sigma, seed, **kwargs):
rng = np.random.default_rng(seed)
return {"y": rng.normal(mu, sigma)}


# --- Tests with observed variables ---
@pytest.mark.parametrize("model", [centered_eight, bmb_model])
def test_sbc(model):
def test_sbc_with_observed_data(model):
sbc = simuk.SBC(
model,
num_simulations=10,
Expand All @@ -37,22 +92,89 @@ def test_sbc(model):
assert "prior_sbc" in sbc.simulations


def test_sbc_numpyro():
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def test_sbc_numpyro_with_observed_data():
sbc = simuk.SBC(
NUTS(eight_schools_cauchy_prior),
data_dir={"J": 8, "sigma": sigma, "y": data},
num_simulations=10,
sample_kwargs={"num_warmup": 50, "num_samples": 25},
)
sbc.run_simulations()
assert "prior_sbc" in sbc.simulations

def eight_schools_cauchy_prior(J, sigma, y=None):
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
theta = numpyro.sample("theta", dist.Normal(mu, tau))
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)

# --- Tests with custom simulators ---
@pytest.mark.parametrize(
"model,simulator",
[
# Case 1: Both simulator function and observed variables present
(centered_eight, centered_eight_simulator),
# Case 2: Only simulator function present
(centered_eight_no_observed, centered_eight_simulator),
],
)
def test_sbc_with_custom_simulator(model, simulator):
sbc = simuk.SBC(
NUTS(eight_schools_cauchy_prior),
data_dir={"J": 8, "sigma": sigma, "y": y},
model, num_simulations=10, sample_kwargs={"draws": 5, "tune": 5}, simulator=simulator
)
sbc.run_simulations()
assert "prior_sbc" in sbc.simulations


@pytest.mark.skipif(
hasattr(bmb, "__version__") and tuple(map(int, bmb.__version__.split("."))) <= (0, 14),
reason="requires bambi version > 0.14",
)
def test_sbc_bambi_with_custom_simulator():
sbc = simuk.SBC(
bmb_model,
num_simulations=10,
sample_kwargs={"draws": 5, "tune": 5},
simulator=bmb_simulator,
)
sbc.run_simulations()
assert "prior_sbc" in sbc.simulations


@pytest.mark.parametrize(
"model,simulator",
[
# Case 1: Both simulator function and observed variables present
(eight_schools_cauchy_prior, centered_eight_simulator),
# Case 2: Only simulator function present
(eight_schools_cauchy_prior_no_observed, centered_eight_simulator),
],
)
def test_sbc_numpyro_with_custom_simulator(model, simulator):
sbc = simuk.SBC(
NUTS(model),
data_dir={"J": 8, "sigma": sigma, "y": data},
num_simulations=10,
sample_kwargs={"num_warmup": 50, "num_samples": 25},
simulator=simulator,
)
sbc.run_simulations()
assert "prior_sbc" in sbc.simulations


# --- Error handling tests with custom simulators ---
def test_sbc_fail_no_observed_variable():
with pytest.raises(ValueError, match="no observed variables"):
simuk.SBC(
centered_eight_no_observed,
num_simulations=10,
sample_kwargs={"draws": 5, "tune": 5},
)


def test_sbc_numpyro_fail_no_observed_variable():
# Note: factor variables are catalogued as 'observed_vars' in NumPyro
# therefore, we cannot raise an early exception with an informative message
with pytest.raises(ValueError):
sbc = simuk.SBC(
NUTS(eight_schools_cauchy_prior_no_observed),
data_dir={"J": 8, "sigma": sigma, "y": data},
num_simulations=10,
sample_kwargs={"num_warmup": 50, "num_samples": 25},
)
sbc.run_simulations()
Loading