Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 90 additions & 84 deletions external_tests/test_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest

from arviz_base.io_numpyro import SVIWrapper, from_numpyro, from_numpyro_svi
from arviz_base.io_numpyro import from_numpyro, from_numpyro_svi
from arviz_base.testing import check_multiple_attrs

from .helpers import importorskip, load_cached_models
Expand All @@ -18,6 +18,62 @@
numpyro.set_host_device_count(2)


def _is_svi_data(data_obj):
"""Check if data object is SVI (dict format) or MCMC."""
return isinstance(data_obj, dict)


def _get_model_from_data(data_obj):
    """Return the model behind either an MCMC object or an SVI data dict.

    For MCMC the sampler's ``model`` is returned. For SVI data the guide's
    ``model`` attribute is preferred, falling back to the ``model`` stored on
    the SVI object when the guide has no such attribute.
    """
    if not _is_svi_data(data_obj):
        return data_obj.sampler.model

    svi = data_obj["svi"]
    guide = svi.guide
    fallback_model = svi.model
    return getattr(guide, "model", fallback_model)


def _get_samples_from_data(data_obj, *, num_samples=1000, seed=0):
    """Extract posterior samples from either MCMC or SVI data.

    Parameters
    ----------
    data_obj : dict or object
        Either SVI data as a dict with ``"svi"`` and ``"svi_result"`` entries
        (and optionally ``"model_args"`` / ``"model_kwargs"``), or a non-dict
        object exposing ``get_samples()``.
    num_samples : int, default 1000
        Number of draws taken from the guide. Only used for SVI data.
    seed : int, default 0
        Seed for the PRNG key used to draw from the guide. Only used for SVI data.

    Returns
    -------
    The result of ``data_obj.get_samples()`` for non-dict input; otherwise the
    draws produced by the guide.
    """
    if not _is_svi_data(data_obj):
        return data_obj.get_samples()

    import jax

    svi = data_obj["svi"]
    svi_result = data_obj["svi_result"]
    # ``or`` covers both a missing key and a key stored explicitly as None,
    # either of which would otherwise break the ``*`` / ``**`` unpacking below.
    model_args = data_obj.get("model_args") or ()
    model_kwargs = data_obj.get("model_kwargs") or {}

    key = jax.random.PRNGKey(seed)
    if isinstance(svi.guide, numpyro.infer.autoguide.AutoGuide):
        # Autoguides can draw from the fitted posterior directly.
        return svi.guide.sample_posterior(
            key, svi_result.params, *model_args, sample_shape=(num_samples,), **model_kwargs
        )
    # Other guides are sampled through Predictive using the fitted parameters.
    predictive = Predictive(svi.guide, params=svi_result.params, num_samples=num_samples)
    return predictive(key, *model_args, **model_kwargs)


def _from_numpyro_auto(data_obj, **kwargs):
    """Convert ``data_obj`` with the converter that matches its type.

    SVI data (a dict) is unpacked as keyword arguments into ``from_numpyro_svi``;
    anything else is handed to ``from_numpyro`` as ``posterior``. Extra keyword
    arguments are forwarded unchanged to whichever converter is used.
    """
    if not _is_svi_data(data_obj):
        return from_numpyro(posterior=data_obj, **kwargs)
    return from_numpyro_svi(**data_obj, **kwargs)


def _from_numpyro_inference_result(result_dict, **kwargs):
    """Convert a ``_run_inference`` result dict with the matching converter.

    The dict is expected to take one of two forms:
    - {"svi": svi, "svi_result": result} for SVI
    - {"posterior": mcmc} for MCMC

    The presence of the ``"svi"`` key selects ``from_numpyro_svi``; otherwise
    ``from_numpyro`` is used. The dict entries and any extra keyword arguments
    are passed to the converter as keyword arguments.
    """
    converter = from_numpyro_svi if "svi" in result_dict else from_numpyro
    return converter(**result_dict, **kwargs)


class TestDataNumPyro:
@pytest.fixture(scope="class", params=["numpyro", "numpyro_svi", "numpyro_svi_custom_guide"])
def data(self, request, eight_schools_params, draws, chains):
Expand All @@ -37,9 +93,8 @@ def predictions_params(self):
@pytest.fixture(scope="class")
def predictions_data(self, data, predictions_params):
"""Generate predictions for predictions_params"""
posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj
posterior_samples = posterior.get_samples()
model = posterior.sampler.model
posterior_samples = _get_samples_from_data(data.obj)
model = _get_model_from_data(data.obj)
predictions = Predictive(model, posterior_samples)(
PRNGKey(2), predictions_params["J"], predictions_params["sigma"]
)
Expand All @@ -48,35 +103,26 @@ def predictions_data(self, data, predictions_params):
def get_inference_data(
self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
):
if isinstance(data.obj, dict): # SVI cached data obj is a tuple
posterior = SVIWrapper(**data.obj)
from_numpyro_func = from_numpyro_svi
posterior_kwarg = data.obj
else: # regular MCMC
posterior = data.obj
from_numpyro_func = from_numpyro
posterior_kwarg = {"posterior": posterior}

posterior_samples = posterior.get_samples()
model = posterior.sampler.model
posterior_samples = _get_samples_from_data(data.obj)
model = _get_model_from_data(data.obj)

posterior_predictive = Predictive(model, posterior_samples)(
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
prior = Predictive(model, num_samples=500)(
PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
)

dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
if infer_dims:
dims = pred_dims = None

predictions = predictions_data

return from_numpyro_func(
**posterior_kwarg,
return _from_numpyro_auto(
data.obj,
prior=prior,
posterior_predictive=posterior_predictive,
predictions=predictions,
predictions=predictions_data,
coords={
"school": np.arange(eight_schools_params["J"]),
"school_pred": np.arange(predictions_params["J"]),
Expand All @@ -86,7 +132,10 @@ def get_inference_data(
)

def test_inference_data_namedtuple(self, data):
posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj
if _is_svi_data(data.obj):
pytest.skip("Namedtuple test only applies to MCMC")

posterior = data.obj
samples = posterior.get_samples()
Samples = namedtuple("Samples", samples)
data_namedtuple = Samples(**samples)
Expand Down Expand Up @@ -128,9 +177,11 @@ def test_inference_data(self, data, eight_schools_params, predictions_data, pred
def test_inference_data_no_posterior(
self, data, eight_schools_params, predictions_data, predictions_params
):
posterior = SVIWrapper(**data.obj) if isinstance(data.obj, dict) else data.obj
posterior_samples = posterior.get_samples()
model = posterior.sampler.model
if _is_svi_data(data.obj):
pytest.skip("This test only runs with MCMC (numpyro)")

posterior_samples = _get_samples_from_data(data.obj)
model = _get_model_from_data(data.obj)
posterior_predictive = Predictive(model, posterior_samples)(
PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
)
Expand Down Expand Up @@ -177,14 +228,12 @@ def test_inference_data_no_posterior(
assert not fails, f"prior and posterior_predictive: {fails}"

def test_inference_data_only_posterior(self, data):
kwargs = data.obj if isinstance(data.obj, dict) else {"posterior": data.obj}
from_numpyro_func = from_numpyro_svi if isinstance(data.obj, dict) else from_numpyro
idata = from_numpyro_func(**kwargs)
idata = _from_numpyro_auto(data.obj)
test_dict = {
"posterior": ["mu", "tau", "eta"],
"sample_stats": ["diverging"],
}
if isinstance(data.obj, dict):
if _is_svi_data(data.obj):
test_dict.pop("sample_stats")
fails = check_multiple_attrs(test_dict, idata)
assert not fails
Expand Down Expand Up @@ -255,7 +304,9 @@ def model_constant_data(x, y1=None):
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

def test_inference_data_num_chains(self, predictions_data, chains):
def test_inference_data_num_chains(self, data, predictions_data, chains):
if _is_svi_data(data.obj):
pytest.skip("This test only runs with MCMC (numpyro)")
predictions = predictions_data
inference_data = from_numpyro(predictions=predictions, num_chains=chains)
nchains = inference_data.predictions.sizes["chain"]
Expand Down Expand Up @@ -333,11 +384,10 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
sample_dims = ("sample",) if svi else ("chain", "draw")

inference_data = from_numpyro_func(
**result, coords={"group1": np.arange(10), "group2": np.arange(5)}
inference_data = _from_numpyro_inference_result(
result, coords={"group1": np.arange(10), "group2": np.arange(5)}
)
assert inference_data.posterior.param.dims == sample_dims + ("group1", "group2")
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
Expand Down Expand Up @@ -379,11 +429,10 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
sample_dims = ("sample",) if svi else ("chain", "draw")

inference_data = from_numpyro_func(
**result, coords={"group1": np.arange(10), "group2": np.arange(5)}
inference_data = _from_numpyro_inference_result(
result, coords={"group1": np.arange(10), "group2": np.arange(5)}
)
assert inference_data.posterior.param.dims == sample_dims + ("group2", "group1")
assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
Expand Down Expand Up @@ -416,10 +465,9 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
sample_dims = ("sample",) if svi else ("chain", "draw")

inference_data = from_numpyro_func(**result)
inference_data = _from_numpyro_inference_result(result)
assert inference_data.posterior.param.dims == sample_dims + ("group",)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -452,10 +500,9 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
sample_dims = ("sample",) if svi else ("chain", "draw")

inference_data = from_numpyro_func(**result, coords={"groups": np.arange(10)})
inference_data = _from_numpyro_inference_result(result, coords={"groups": np.arange(10)})
assert inference_data.posterior.gamma.dims == sample_dims + ("groups",)
assert "groups" in inference_data.posterior.gamma.coords

Expand Down Expand Up @@ -500,9 +547,8 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
with pytest.raises(ValueError):
from_numpyro_func(**result, coords={"obs_idx": np.arange(3)})
_from_numpyro_inference_result(result, coords={"obs_idx": np.arange(3)})

@pytest.mark.parametrize(
"svi,guide_fn",
Expand Down Expand Up @@ -534,10 +580,9 @@ def guide():
guide_fn = guide

result = self._run_inference(model, svi=svi, guide_fn=guide_fn)
from_numpyro_func = from_numpyro_svi if svi else from_numpyro
sample_dims = ("sample",) if svi else ("chain", "draw")
inference_data = from_numpyro_func(
**result, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
inference_data = _from_numpyro_inference_result(
result, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
)
assert inference_data.posterior.gamma_plus1.dims == sample_dims + ("groups",)
assert "groups" in inference_data.posterior.gamma_plus1.coords
Expand All @@ -548,7 +593,7 @@ def test_predictions_infer_dims(
inference_data = self.get_inference_data(
data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
)
sample_dims = ("sample",) if isinstance(data.obj, dict) else ("chain", "draw")
sample_dims = ("sample",) if _is_svi_data(data.obj) else ("chain", "draw")
assert inference_data.predictions.obs.dims == (sample_dims + ("J",))
assert "J" in inference_data.predictions.obs.coords

Expand All @@ -564,48 +609,9 @@ def _run_inference(self, model, svi, guide_fn):
return {
"svi": svi,
"svi_result": svi_result,
"model": None if is_autoguide else model,
}

else:
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(PRNGKey(0))
return {"posterior": mcmc}


class TestSVIWrapper:
    """Tests for ``SVIWrapper`` built from SVI objects and their run results."""

    @pytest.fixture(scope="class", params=["numpyro_svi", "numpyro_svi_custom_guide"])
    def data(self, request, eight_schools_params, draws, chains):
        """Expose the cached model selected by the fixture param as ``Data.obj``."""
        cached_models = load_cached_models(eight_schools_params, draws, chains, "numpyro")

        class Data:
            obj = cached_models[request.param]

        return Data

    def test_init_without_args_kwargs(self):
        """Without model args/kwargs the wrapper still holds a tuple and a dict."""
        from numpyro.infer import Trace_ELBO
        from numpyro.infer.svi import SVI, SVIRunResult
        from numpyro.optim import Adam

        # The same trivial callable serves as both model and guide.
        model = guide = lambda x: x
        svi = SVI(model, guide, optim=Adam(0.05), loss=Trace_ELBO())
        run_result = SVIRunResult(
            params=jax.numpy.ones(5),
            state=None,
            losses=jax.numpy.zeros(10),
        )

        wrapper = SVIWrapper(svi, svi_result=run_result)
        assert isinstance(wrapper._args, tuple)
        assert isinstance(wrapper._kwargs, dict)

    def test_get_samples(self, data, eight_schools_params):
        """``get_samples`` returns a dict whose values are arrays."""
        cached = data.obj
        wrapper = SVIWrapper(
            cached["svi"],
            svi_result=cached["svi_result"],
            model_kwargs=eight_schools_params,
        )
        samples = wrapper.get_samples(seed=0)
        assert isinstance(samples, dict)
        for value in samples.values():  # values are array-like
            assert isinstance(value, (jax.numpy.ndarray | np.ndarray))

    def test_sampler_attr(self, data, eight_schools_params):
        """The wrapper exposes ``sampler`` and ``sampler.model``."""
        cached = data.obj
        wrapper = SVIWrapper(
            cached["svi"],
            svi_result=cached["svi_result"],
            model_kwargs=eight_schools_params,
        )
        assert hasattr(wrapper, "sampler")
        assert hasattr(wrapper.sampler, "model")
Loading