Skip to content

Commit d08754c

Browse files
committed
added tests for SVIWrapper
1 parent 2265057 commit d08754c

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

external_tests/test_numpyro.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,72 @@ def _run_inference(self, model, svi, guide_fn):
571571
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
572572
mcmc.run(PRNGKey(0))
573573
return {"posterior": mcmc}
574+
575+
576+
class TestSVIWrapper:
577+
# @pytest.fixture(scope="class", params=["numpyro", "numpyro_svi"])
578+
# def data(self, request):
579+
# from numpyro.infer.svi import SVI, SVIRunResult
580+
# from numpyro.infer import Trace_ELBO
581+
# from numpyro.optim import Adam
582+
# from numpyro import distributions as dist
583+
584+
# def model():
585+
# numpyro.sample("alpha", dist.Normal(0, 10))
586+
587+
# return {
588+
# "numpyro": (numpyro.infer.SVI()
589+
# }[request.param]
590+
591+
# return Data
592+
593+
@pytest.fixture(scope="class", params=["numpyro_svi", "numpyro_svi_custom_guide"])
594+
def data(self, request, eight_schools_params, draws, chains):
595+
class Data:
596+
obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")[request.param]
597+
598+
return Data
599+
600+
# def test_init_defaults(self):
601+
# wrapper = SVIWrapper(
602+
603+
# )
604+
605+
# assert wrapper.svi == "mysvi"
606+
# assert wrapper.svi_result == {"params": {}}
607+
# assert wrapper.num_samples == 1000
608+
# assert wrapper.thinning == 1
609+
# assert wrapper.num_chains == 0
610+
# assert wrapper.sample_dims == ["samples"]
611+
# assert wrapper.kind == "svi"
612+
# assert callable(wrapper.prng_key_func)
613+
# assert hasattr(wrapper, "numpyro")
614+
615+
def test_init_without_args_kwargs(self):
616+
from numpyro.infer import Trace_ELBO
617+
from numpyro.infer.svi import SVI, SVIRunResult
618+
from numpyro.optim import Adam
619+
620+
model = guide = lambda x: x
621+
svi = SVI(model, guide, optim=Adam(0.05), loss=Trace_ELBO())
622+
svi_result = SVIRunResult(params=jax.numpy.ones(5), state=None, losses=jax.numpy.zeros(10))
623+
624+
posterior = SVIWrapper(svi, svi_result=svi_result)
625+
assert isinstance(posterior._args, tuple)
626+
assert isinstance(posterior._kwargs, dict)
627+
628+
def test_get_samples(self, data, eight_schools_params):
629+
svi_posterior = SVIWrapper(
630+
data.obj["svi"], svi_result=data.obj["svi_result"], model_kwargs=eight_schools_params
631+
)
632+
out = svi_posterior.get_samples(seed=0)
633+
assert isinstance(out, dict)
634+
for v in out.values(): # values are array-like
635+
assert isinstance(v, (jax.numpy.ndarray | np.ndarray))
636+
637+
def test_sampler_attr(self, data, eight_schools_params):
638+
svi_posterior = SVIWrapper(
639+
data.obj["svi"], svi_result=data.obj["svi_result"], model_kwargs=eight_schools_params
640+
)
641+
assert hasattr(svi_posterior, "sampler")
642+
assert hasattr(svi_posterior.sampler, "model")

src/arviz_base/io_numpyro.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
class SVIWrapper:
15-
"""A helper class for SVI to mimic MCMC methods."""
15+
"""A helper class for SVI to mimic numpyro.infer.MCMC methods."""
1616

1717
def __init__(
1818
self,
@@ -624,8 +624,8 @@ def from_numpyro(
624624

625625
def from_numpyro_svi(
626626
svi,
627-
svi_result,
628627
*,
628+
svi_result,
629629
model_args=None,
630630
model_kwargs=None,
631631
prior=None,
@@ -663,7 +663,7 @@ def from_numpyro_svi(
663663
664664
Parameters
665665
----------
666-
guide : numpyro.infer.svi.SVI
666+
svi : numpyro.infer.svi.SVI
667667
Numpyro SVI instance used for fitting the model.
668668
svi_result : numpyro.infer.svi.SVIRunResult
669669
SVI results from a fitted model.
@@ -720,5 +720,5 @@ def from_numpyro_svi(
720720
dims=dims,
721721
pred_dims=pred_dims,
722722
extra_event_dims=extra_event_dims,
723-
num_chains=1,
723+
num_chains=0,
724724
).to_datatree()

0 commit comments

Comments
 (0)