Skip to content

Commit de22e71

Browse files
Carlosbogo and fehiepsi authored
Allow prior to be a callable in random module (#1227)
* Solves #1224 Allows prior to be a user-defined function in random module * Changes to #1227 Fixed a mistake in #1227 and added the new functionality to the docstrings. * Changes to #1227 Added the requested changes to #1227 and a test for the new functionality, namely `test_random_module_mcmc_callable`. * Allow prior to be a callable in random module. Fixes PR #1227 Added the changes mentioned by @fehiepsi to PR #1227 in `test_module`. * Update to PR #1227 Added the changes proposed by @fehiepsi to PR #1227 in order to fix the errors present in it. * Allow prior to be a callable. Changes to PR #1227 Fixes the errors spotted by @fehiepsi in my previous commits for PR #1227. * Run black * Make sure that callable prior is not a distribution Co-authored-by: Du Phan <[email protected]>
1 parent a35e925 commit de22e71

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

numpyro/contrib/module.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten
1111

1212
import numpyro
13+
import numpyro.distributions as dist
1314
from numpyro.primitives import mutable as numpyro_mutable
1415

1516
__all__ = [
@@ -223,12 +224,17 @@ def _update_params(params, new_params, prior, prefix=""):
223224
new_item = new_params[name]
224225
_update_params(item, new_item, prior, prefix=flatten_name)
225226
elif (not isinstance(prior, dict)) or flatten_name in prior:
226-
d = prior[flatten_name] if isinstance(prior, dict) else prior
227227
if isinstance(params[name], ParamShape):
228228
param_shape = params[name].shape
229229
else:
230230
param_shape = jnp.shape(params[name])
231231
params[name] = ParamShape(param_shape)
232+
if isinstance(prior, dict):
233+
d = prior[flatten_name]
234+
elif callable(prior) and not isinstance(prior, dist.Distribution):
235+
d = prior(flatten_name, param_shape)
236+
else:
237+
d = prior
232238
param_batch_shape = param_shape[: len(param_shape) - d.event_dim]
233239
# XXX: here we set all dimensions of prior to event dimensions.
234240
new_params[name] = numpyro.sample(
@@ -270,7 +276,12 @@ def __call__(self, x):
270276
prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
271277
input_shape=(4,))
272278
273-
:type prior: dict or ~numpyro.distributions.Distribution
279+
Alternatively, we can use a callable. For example the following are equivalent::
280+
281+
prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal())
282+
prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}
283+
284+
:type prior: dict, ~numpyro.distributions.Distribution or callable
274285
:param tuple input_shape: shape of the input taken by the neural network.
275286
:param list apply_rng: A list to indicate which extra rng _kinds_ are needed for
276287
``nn_module``. For example, when ``nn_module`` includes dropout layers, we
@@ -374,7 +385,12 @@ def random_haiku_module(
374385
prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
375386
input_shape=(4,))
376387
377-
:type prior: dict or ~numpyro.distributions.Distribution
388+
Alternatively, we can use a callable. For example the following are equivalent::
389+
390+
prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal())
391+
prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}
392+
393+
:type prior: dict, ~numpyro.distributions.Distribution or callable
378394
:param tuple input_shape: shape of the input taken by the neural network.
379395
:param bool apply_rng: A flag to indicate if the returned callable requires
380396
an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults

test/contrib/test_module.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def test_update_params():
148148

149149
@pytest.mark.parametrize("backend", ["flax", "haiku"])
150150
@pytest.mark.parametrize("init", ["shape", "kwargs"])
151-
def test_random_module_mcmc(backend, init):
151+
@pytest.mark.parametrize("callable_prior", [True, False])
152+
def test_random_module_mcmc(backend, init, callable_prior):
152153

153154
if backend == "flax":
154155
import flax
@@ -179,13 +180,15 @@ def test_random_module_mcmc(backend, init):
179180
elif init == "kwargs":
180181
kwargs = {kwargs_name: data}
181182

182-
def model(data, labels):
183-
nn = random_module(
184-
"nn",
185-
linear_module,
186-
{bias_name: dist.Cauchy(), weight_name: dist.Normal()},
187-
**kwargs
183+
if callable_prior:
184+
prior = (
185+
lambda name, shape: dist.Cauchy() if name == bias_name else dist.Normal()
188186
)
187+
else:
188+
prior = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}
189+
190+
def model(data, labels):
191+
nn = random_module("nn", linear_module, prior=prior, **kwargs)
189192
logits = nn(data).squeeze(-1)
190193
numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)
191194

0 commit comments

Comments
 (0)