Skip to content

Commit f5454aa

Browse files
author
Juan Orduz
authored
Improve rng default in the contrib/model.py module (#1992)
* fix * rm comment * alternative rng * suggestion
1 parent 4b7d233 commit f5454aa

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

numpyro/contrib/module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def flax_module(
8585
# feed in dummy data to init params
8686
args = (jnp.ones(input_shape),) if input_shape is not None else args
8787
rng_key = numpyro.prng_key()
88+
if rng_key is None:
89+
rng_key = random.key(0)
8890
# split rng_key into a dict of rng_kind: rng_key
8991
rngs = {}
9092
if apply_rng:
@@ -187,6 +189,8 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw
187189
args = (jnp.ones(input_shape),) if input_shape is not None else args
188190
# feed in dummy data to init params
189191
rng_key = numpyro.prng_key()
192+
if rng_key is None:
193+
rng_key = random.key(0)
190194
if with_state:
191195
nn_params, nn_state = nn_module.init(rng_key, *args, **kwargs)
192196
nn_state = dict(nn_state)

test/infer/test_infer_util.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,44 @@ def guide():
519519
assert "guide-always" in called
520520
assert "model-sometimes" not in called
521521
assert "guide-sometimes" not in called
522+
523+
524+
def test_log_likelihood_flax_nn():
525+
import numpy as np
526+
527+
import flax.linen as nn
528+
from jax import random
529+
530+
from numpyro.contrib.module import random_flax_module
531+
532+
# Simulate
533+
rng = np.random.default_rng(99)
534+
N = 1000
535+
536+
X = rng.normal(0, 1, size=(N, 1))
537+
mu = 1 + X @ np.array([0.5])
538+
y = rng.normal(mu, 0.5)
539+
540+
# Simple linear layer
541+
class Linear(nn.Module):
542+
@nn.compact
543+
def __call__(self, x):
544+
return nn.Dense(1, use_bias=True, name="Dense")(x)
545+
546+
def model(X, y=None):
547+
sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
548+
priors = {"Dense.bias": dist.Normal(0, 2.5), "Dense.kernel": dist.Normal(0, 1)}
549+
mlp = random_flax_module(
550+
"mlp", Linear(), prior=priors, input_shape=(X.shape[1],)
551+
)
552+
with numpyro.plate("data", X.shape[0]):
553+
mu = numpyro.deterministic("mu", mlp(X).squeeze(-1))
554+
y = numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
555+
556+
# Fit model
557+
kernel = NUTS(model, target_accept_prob=0.95)
558+
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1)
559+
mcmc.run(random.PRNGKey(0), X=X, y=y)
560+
561+
# run log likelihood
562+
numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y)

0 commit comments

Comments
 (0)