Skip to content

Commit 8dffcd2

Browse files
committed
changed input for form_numpyro_svi from guide to SVI instance
1 parent dae0282 commit 8dffcd2

File tree

4 files changed

+16
-23
lines changed

4 files changed

+16
-23
lines changed

external_tests/helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def numpyro_schools_model_svi(data, draws, chains):
169169
guide = AutoNormal(_numpyro_noncentered_model, init_loc_fn=init_to_sample())
170170
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
171171
svi_result = svi.run(PRNGKey(0), 4000, **data)
172-
return {"guide": guide, "svi_result": svi_result, "model_kwargs": data}
172+
return {"svi": svi, "svi_result": svi_result, "model_kwargs": data}
173173

174174

175175
def numpyro_schools_model_svi_custom_guide(data, draws, chains):
@@ -182,10 +182,9 @@ def numpyro_schools_model_svi_custom_guide(data, draws, chains):
182182
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
183183
svi_result = svi.run(PRNGKey(0), 4000, **data)
184184
return {
185-
"guide": guide,
185+
"svi": svi,
186186
"svi_result": svi_result,
187187
"model_kwargs": data,
188-
"model": _numpyro_noncentered_model,
189188
}
190189

191190

external_tests/test_numpyro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def _run_inference(self, model, svi, guide_fn=autoguide.AutoNormal):
564564
svi = SVI(model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
565565
svi_result = svi.run(PRNGKey(0), 10)
566566
return {
567-
"guide": guide,
567+
"svi": svi,
568568
"svi_result": svi_result,
569569
"model": None if is_autoguide else model,
570570
}

src/arviz_base/io_numpyro.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,15 @@ class SVIWrapper:
1616

1717
def __init__(
1818
self,
19-
guide,
19+
svi,
2020
*,
2121
svi_result,
2222
model_args=None,
2323
model_kwargs=None,
2424
num_samples: int = 1000,
25-
model=None,
2625
thinning: int = 1,
2726
):
28-
self.guide = guide
27+
self.svi = svi
2928
self.svi_result = svi_result
3029
self._args = model_args or tuple()
3130
self._kwargs = model_kwargs or dict()
@@ -34,7 +33,6 @@ def __init__(
3433
self.num_chains = 0
3534
self.sample_dims = ["samples"]
3635
self.kind = "svi"
37-
self.model = model
3836

3937
def get_samples(self, seed=None, **kwargs):
4038
"""Mimics mcmc.get_samples()."""
@@ -43,8 +41,8 @@ def get_samples(self, seed=None, **kwargs):
4341
from numpyro.infer.autoguide import AutoGuide
4442

4543
key = jax.random.PRNGKey(seed or 0)
46-
if isinstance(self.guide, AutoGuide):
47-
return self.guide.sample_posterior(
44+
if isinstance(self.svi.guide, AutoGuide):
45+
return self.svi.guide.sample_posterior(
4846
key,
4947
self.svi_result.params,
5048
*self._args,
@@ -53,7 +51,7 @@ def get_samples(self, seed=None, **kwargs):
5351
)
5452
# if a custom guide is provided, sample by hand
5553
predictive = Predictive(
56-
self.guide, params=self.svi_result.params, num_samples=self.num_samples
54+
self.svi.guide, params=self.svi_result.params, num_samples=self.num_samples
5755
)
5856
samples = predictive(key, *self._args, **self._kwargs)
5957
return samples
@@ -70,7 +68,7 @@ def __init__(self, model):
7068
def model(self):
7169
return self._model
7270

73-
return Sampler(getattr(self.guide, "model", self.model))
71+
return Sampler(getattr(self.svi.guide, "model", self.svi.model))
7472

7573
def get_extra_fields(self, **kwargs):
7674
"""Mimics mcmc.get_extra_fields()."""
@@ -623,7 +621,7 @@ def from_numpyro(
623621

624622

625623
def from_numpyro_svi(
626-
guide,
624+
svi,
627625
svi_result,
628626
*,
629627
model_args=None,
@@ -663,8 +661,8 @@ def from_numpyro_svi(
663661
664662
Parameters
665663
----------
666-
guide : numpyro.infer.autoguide.AutoGuide or callable
667-
Guide function for a numpyro SVI model. Can be an autoguide or custom guide.
664+
guide : numpyro.infer.svi.SVI
665+
Numpyro SVI instance used for fitting the model.
668666
svi_result : numpyro.infer.svi.SVIRunResult
669667
SVI results from a fitted model.
670668
model_args : tuple, optional
@@ -694,20 +692,17 @@ def from_numpyro_svi(
694692
their coordinates.
695693
num_chains : int, default 1
696694
Number of chains used for sampling. Ignored if posterior is present.
697-
model : callable, optional
698-
Model function, only needed for a custom guide function
699695
700696
Returns
701697
-------
702698
DataTree
703699
"""
704700
posterior = SVIWrapper(
705-
guide,
701+
svi,
706702
svi_result=svi_result,
707703
model_args=model_args,
708704
model_kwargs=model_kwargs,
709705
num_samples=num_samples,
710-
model=model,
711706
)
712707
with rc_context(rc={"data.sample_dims": ["samples"]}):
713708
return NumPyroConverter(

src/arviz_base/io_numpyro.pyi

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@ from arviz_base.utils import expand_dims
1818
class SVIWrapper:
1919
def __init__(
2020
self,
21-
guide,
21+
svi,
2222
*,
2323
svi_result,
2424
model_args=...,
2525
model_kwargs=...,
2626
num_samples: int = ...,
27-
model=...,
2827
thinning: int = ...,
2928
) -> None: ...
3029
def get_samples(self, seed=..., **kwargs) -> None: ...
@@ -96,7 +95,7 @@ def from_numpyro(
9695
num_chains: int = ...,
9796
) -> xarray.DataTree: ...
9897
def from_numpyro_svi(
99-
guide: numpyro.infer.autoguide.AutoGuide | Callable,
98+
svi,
10099
svi_result: numpyro.infer.svi.SVIRunResult,
101100
*,
102101
model_args: tuple | None = ...,
@@ -112,6 +111,6 @@ def from_numpyro_svi(
112111
dims: dict[str, list[str]] | None = ...,
113112
pred_dims: dict | None = ...,
114113
extra_event_dims: dict | None = ...,
115-
model: Callable | None = ...,
114+
model=...,
116115
num_samples: int = ...,
117116
) -> xarray.DataTree: ...

0 commit comments

Comments
 (0)