Skip to content

Commit 2265057

Browse files
committed
improved comments, imports
1 parent 8f81926 commit 2265057

File tree

3 files changed

+35
-39
lines changed

3 files changed

+35
-39
lines changed

external_tests/helpers.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,18 @@ def _numpyro_noncentered_guide(J, sigma, y=None):
117117
# Variational parameters for mu
118118
mu_loc = numpyro.param("mu_loc", 0.0)
119119
mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
120-
mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
120+
numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
121121

122122
# Variational parameters for tau (positive support)
123123
tau_loc = numpyro.param("tau_loc", 1.0)
124124
tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
125-
tau = numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale))
125+
numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale))
126126

127127
# Variational parameters for eta
128128
eta_loc = numpyro.param("eta_loc", jax.numpy.zeros(J))
129129
eta_scale = numpyro.param("eta_scale", jax.numpy.ones(J), constraint=dist.constraints.positive)
130130
with numpyro.plate("J", J):
131-
eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))
132-
133-
# theta is deterministic; obs is handled in the model
134-
theta = mu + tau * eta
135-
return theta
131+
numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))
136132

137133

138134
def numpyro_schools_model(data, draws, chains):

external_tests/test_numpyro.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@
1414
PRNGKey = jax.random.PRNGKey
1515
numpyro = importorskip("numpyro")
1616
Predictive = numpyro.infer.Predictive
17-
numpyro.set_host_device_count(2)
18-
dist = numpyro.distributions
19-
AutoNormal = numpyro.infer.autoguide.AutoNormal
20-
AutoDelta = numpyro.infer.autoguide.AutoDelta
2117
autoguide = numpyro.infer.autoguide
18+
numpyro.set_host_device_count(2)
2219

2320

2421
class TestDataNumPyro:
@@ -309,12 +306,13 @@ def model():
309306
"svi,guide_fn",
310307
[
311308
(False, None), # MCMC, guide ignored
312-
(True, AutoDelta), # SVI with AutoDelta
313-
(True, AutoNormal), # SVI with AutoNormal
309+
(True, autoguide.AutoDelta), # SVI with AutoDelta
310+
(True, autoguide.AutoNormal), # SVI with AutoNormal
314311
(True, "custom"), # SVI with custom guide
315312
],
316313
)
317314
def test_infer_dims(self, svi, guide_fn):
315+
import jax.numpy as jnp
318316
import numpyro
319317
import numpyro.distributions as dist
320318

@@ -324,9 +322,9 @@ def model():
324322
_ = numpyro.sample("param", dist.Normal(0, 1))
325323

326324
def guide():
327-
loc = numpyro.param("param_loc", jax.numpy.zeros((10, 5)))
325+
loc = numpyro.param("param_loc", jnp.zeros((10, 5)))
328326
scale = numpyro.param(
329-
"param_scale", jax.numpy.ones((10, 5)), constraint=dist.constraints.positive
327+
"param_scale", jnp.ones((10, 5)), constraint=dist.constraints.positive
330328
)
331329
with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
332330
numpyro.sample("param", dist.Normal(loc, scale))
@@ -348,12 +346,13 @@ def guide():
348346
"svi,guide_fn",
349347
[
350348
(False, None), # MCMC, guide ignored
351-
(True, AutoDelta), # SVI with AutoDelta
352-
(True, AutoNormal), # SVI with AutoNormal
349+
(True, autoguide.AutoDelta), # SVI with AutoDelta
350+
(True, autoguide.AutoNormal), # SVI with AutoNormal
353351
(True, "custom"), # SVI with custom guide
354352
],
355353
)
356354
def test_infer_unsorted_dims(self, svi, guide_fn):
355+
import jax.numpy as jnp
357356
import numpyro
358357
import numpyro.distributions as dist
359358

@@ -367,9 +366,9 @@ def model():
367366
_ = numpyro.sample("param", dist.Normal(0, 1))
368367

369368
def guide():
370-
loc = numpyro.param("param_loc", jax.numpy.zeros((5, 10)))
369+
loc = numpyro.param("param_loc", jnp.zeros((5, 10)))
371370
scale = numpyro.param(
372-
"param_scale", jax.numpy.ones((5, 10)), constraint=dist.constraints.positive
371+
"param_scale", jnp.ones((5, 10)), constraint=dist.constraints.positive
373372
)
374373
group1_plate = numpyro.plate("group1", 10, dim=-1)
375374
group2_plate = numpyro.plate("group2", 5, dim=-2)
@@ -393,12 +392,13 @@ def guide():
393392
"svi,guide_fn",
394393
[
395394
(False, None), # MCMC, guide ignored
396-
(True, AutoDelta), # SVI with AutoDelta
397-
(True, AutoNormal), # SVI with AutoNormal
395+
(True, autoguide.AutoDelta), # SVI with AutoDelta
396+
(True, autoguide.AutoNormal), # SVI with AutoNormal
398397
(True, "custom"), # SVI with custom guide
399398
],
400399
)
401400
def test_infer_dims_no_coords(self, svi, guide_fn):
401+
import jax.numpy as jnp
402402
import numpyro
403403
import numpyro.distributions as dist
404404

@@ -407,10 +407,8 @@ def model():
407407
_ = numpyro.sample("param", dist.Normal(0, 1))
408408

409409
def guide():
410-
loc = numpyro.param("param_loc", jax.numpy.zeros(5))
411-
scale = numpyro.param(
412-
"param_scale", jax.numpy.ones(5), constraint=dist.constraints.positive
413-
)
410+
loc = numpyro.param("param_loc", jnp.zeros(5))
411+
scale = numpyro.param("param_scale", jnp.ones(5), constraint=dist.constraints.positive)
414412
with numpyro.plate("group", 5):
415413
numpyro.sample("param", dist.Normal(loc, scale))
416414

@@ -428,8 +426,8 @@ def guide():
428426
"svi,guide_fn",
429427
[
430428
(False, None), # MCMC, guide ignored
431-
(True, AutoDelta), # SVI with AutoDelta
432-
(True, AutoNormal), # SVI with AutoNormal
429+
(True, autoguide.AutoDelta), # SVI with AutoDelta
430+
(True, autoguide.AutoNormal), # SVI with AutoNormal
433431
(True, "custom"), # SVI with custom guide
434432
],
435433
)
@@ -465,8 +463,8 @@ def guide():
465463
"svi,guide_fn",
466464
[
467465
(False, None), # MCMC, guide ignored
468-
(True, AutoDelta), # SVI with AutoDelta
469-
(True, AutoNormal), # SVI with AutoNormal
466+
(True, autoguide.AutoDelta), # SVI with AutoDelta
467+
(True, autoguide.AutoNormal), # SVI with AutoNormal
470468
(True, "custom"), # SVI with custom guide
471469
],
472470
)
@@ -510,8 +508,8 @@ def guide():
510508
"svi,guide_fn",
511509
[
512510
(False, None), # MCMC, guide ignored
513-
(True, AutoDelta), # SVI with AutoDelta
514-
(True, AutoNormal), # SVI with AutoNormal
511+
(True, autoguide.AutoDelta), # SVI with AutoDelta
512+
(True, autoguide.AutoNormal), # SVI with AutoNormal
515513
(True, "custom"), # SVI with custom guide
516514
],
517515
)
@@ -554,7 +552,7 @@ def test_predictions_infer_dims(
554552
assert inference_data.predictions.obs.dims == (sample_dims + ("J",))
555553
assert "J" in inference_data.predictions.obs.coords
556554

557-
def _run_inference(self, model, svi, guide_fn=autoguide.AutoNormal):
555+
def _run_inference(self, model, svi, guide_fn):
558556
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
559557
from numpyro.optim import Adam
560558

src/arviz_base/io_numpyro.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def __init__(
2424
num_samples: int = 1000,
2525
thinning: int = 1,
2626
):
27+
import jax
28+
import numpyro
29+
2730
self.svi = svi
2831
self.svi_result = svi_result
2932
self._args = model_args or tuple()
@@ -34,14 +37,13 @@ def __init__(
3437
self.sample_dims = ["samples"]
3538
self.kind = "svi"
3639

40+
self.numpyro = numpyro
41+
self.prng_key_func = jax.random.PRNGKey
42+
3743
def get_samples(self, seed=None, **kwargs):
3844
"""Mimics mcmc.get_samples()."""
39-
import jax
40-
from numpyro.infer import Predictive
41-
from numpyro.infer.autoguide import AutoGuide
42-
43-
key = jax.random.PRNGKey(seed or 0)
44-
if isinstance(self.svi.guide, AutoGuide):
45+
key = self.prng_key_func(seed or 0)
46+
if isinstance(self.svi.guide, self.numpyro.infer.autoguide.AutoGuide):
4547
return self.svi.guide.sample_posterior(
4648
key,
4749
self.svi_result.params,
@@ -50,7 +52,7 @@ def get_samples(self, seed=None, **kwargs):
5052
**self._kwargs,
5153
)
5254
# if a custom guide is provided, sample by hand
53-
predictive = Predictive(
55+
predictive = self.numpyro.infer.Predictive(
5456
self.svi.guide, params=self.svi_result.params, num_samples=self.num_samples
5557
)
5658
samples = predictive(key, *self._args, **self._kwargs)

0 commit comments

Comments
 (0)