|
10 | 10 | from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten |
11 | 11 |
|
12 | 12 | import numpyro |
| 13 | +import numpyro.distributions as dist |
13 | 14 | from numpyro.primitives import mutable as numpyro_mutable |
14 | 15 |
|
15 | 16 | __all__ = [ |
@@ -223,12 +224,17 @@ def _update_params(params, new_params, prior, prefix=""): |
223 | 224 | new_item = new_params[name] |
224 | 225 | _update_params(item, new_item, prior, prefix=flatten_name) |
225 | 226 | elif (not isinstance(prior, dict)) or flatten_name in prior: |
226 | | - d = prior[flatten_name] if isinstance(prior, dict) else prior |
227 | 227 | if isinstance(params[name], ParamShape): |
228 | 228 | param_shape = params[name].shape |
229 | 229 | else: |
230 | 230 | param_shape = jnp.shape(params[name]) |
231 | 231 | params[name] = ParamShape(param_shape) |
| 232 | + if isinstance(prior, dict): |
| 233 | + d = prior[flatten_name] |
| 234 | + elif callable(prior) and not isinstance(prior, dist.Distribution): |
| 235 | + d = prior(flatten_name, param_shape) |
| 236 | + else: |
| 237 | + d = prior |
232 | 238 | param_batch_shape = param_shape[: len(param_shape) - d.event_dim] |
233 | 239 | # XXX: here we set all dimensions of prior to event dimensions. |
234 | 240 | new_params[name] = numpyro.sample( |
@@ -270,7 +276,12 @@ def __call__(self, x): |
270 | 276 | prior={"bias": dist.Cauchy(), "kernel": dist.Normal()}, |
271 | 277 | input_shape=(4,)) |
272 | 278 |
|
273 | | - :type prior: dict or ~numpyro.distributions.Distribution |
| 279 | + Alternatively, we can use a callable. For example the following are equivalent:: |
| 280 | +
|
| 281 | + prior=(lambda name, shape: dist.Cauchy() if name == "bias" else dist.Normal()) |
| 282 | + prior={"bias": dist.Cauchy(), "kernel": dist.Normal()} |
| 283 | +
|
| 284 | + :type prior: dict, ~numpyro.distributions.Distribution or callable |
274 | 285 | :param tuple input_shape: shape of the input taken by the neural network. |
275 | 286 | :param list apply_rng: A list to indicate which extra rng _kinds_ are needed for |
276 | 287 | ``nn_module``. For example, when ``nn_module`` includes dropout layers, we |
@@ -374,7 +385,12 @@ def random_haiku_module( |
374 | 385 | prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}, |
375 | 386 | input_shape=(4,)) |
376 | 387 |
|
377 | | - :type prior: dict or ~numpyro.distributions.Distribution |
| 388 | + Alternatively, we can use a callable. For example the following are equivalent:: |
| 389 | +
|
 | 390 | +            prior=(lambda name, shape: dist.Cauchy() if name.endswith(".b") else dist.Normal())
 | 391 | +            prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()}
| 392 | +
|
| 393 | + :type prior: dict, ~numpyro.distributions.Distribution or callable |
378 | 394 | :param tuple input_shape: shape of the input taken by the neural network. |
379 | 395 | :param bool apply_rng: A flag to indicate if the returned callable requires |
380 | 396 | an rng argument (e.g. when ``nn_module`` includes dropout layers). Defaults |
|
0 commit comments