Commit f5ae79b

Author: Juan Orduz
Flax NNX integration (#1990)
* This is a combination of 9 commits: init, fix, fix, fix, lint, rm files, rm files, rm files, rm files, rm files
* improvement
* improvements
* add more tests
* add mcmc cases
* Refactor: add example, rm test, patch, feedback 1, feedback 2, feedback 3, improvements, fix test, hacky way to find batch normalization layers, remove unused code, simplify code, modularize tests, simplify, simplify
* support python 3.9
* init example
* fix model
* cleanup nb
* artifacts docs
* remove code: eager initialization approach, rm unused code 1, rm unused code 2, rm code, rm code, rm code, rm code
* rm code
* refactor
* skip tests
* feedback 1/n
* feedback 2/n
* lint
* skip test
* feedback 3/n
* partial
* feedback states
* clean tests part 1
* set priors
* better split
* rm mutable from signature
* fix numpyro_mutable
* fix test
* feedback part 1
* try fix apply function
* Revert "try fix apply function" (reverts commit 17db8c1)
* Revert "Revert "try fix apply function"" (reverts commit 9a870d6)
* rm redundant line
* better test name
* are we updating more than needed?
* rename and bring back update after model call
* add prior component to example
* finalize docs
* to_dict
* stop gradient
* final comments
1 parent 374cd89 commit f5ae79b

File tree

7 files changed: +1008 -4 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -35,4 +35,4 @@ repos:
       - id: codespell
         stages: [pre-commit, commit-msg]
         args:
-          [--ignore-words-list, "Teh,aas,ans", --check-filenames, --skip, "*.ipynb"]
+          [--ignore-words-list, "Teh,aas,ans,dout", --check-filenames, --skip, "*.ipynb"]
(291 KB binary file added; diff not rendered)

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ NumPyro documentation
     tutorials/censoring
     tutorials/hsgp_example
     tutorials/other_samplers
+    tutorials/nnx_example

 .. nbgallery::
     :maxdepth: 1

docs/source/primitives.rst

Lines changed: 8 additions & 0 deletions
@@ -51,6 +51,10 @@ haiku_module
 ------------
 .. autofunction:: numpyro.contrib.module.haiku_module

+nnx_module
+----------
+.. autofunction:: numpyro.contrib.module.nnx_module
+
 random_flax_module
 ------------------
 .. autofunction:: numpyro.contrib.module.random_flax_module
@@ -59,6 +63,10 @@ random_haiku_module
 -------------------
 .. autofunction:: numpyro.contrib.module.random_haiku_module

+random_nnx_module
+-----------------
+.. autofunction:: numpyro.contrib.module.random_nnx_module
+
 scan
 ----
 .. autofunction:: numpyro.contrib.control_flow.scan

notebooks/source/nnx_example.ipynb

Lines changed: 699 additions & 0 deletions
Large diffs are not rendered by default.

numpyro/contrib/module.py

Lines changed: 133 additions & 0 deletions
@@ -19,6 +19,8 @@
     "haiku_module",
     "random_flax_module",
     "random_haiku_module",
+    "nnx_module",
+    "random_nnx_module",
 ]


@@ -433,3 +435,134 @@ def random_haiku_module(
     _update_params(params, new_params, prior)
     nn_new = partial(nn.func, new_params, *nn.args[1:], **nn.keywords)
     return nn_new
+
+
+def nnx_module(name, nn_module):
+    """
+    Declare a :mod:`~flax.nnx` style neural network inside a
+    model so that its parameters are registered for optimization via
+    :func:`~numpyro.primitives.param` statements.
+
+    Given a Flax NNX ``nn_module``, we evaluate the module by calling it
+    directly. In a NumPyro model, the pattern will be::
+
+        # Eager initialization outside the model
+        module = nn_module(...)
+
+        # Inside the model
+        net = nnx_module("net", module)
+        y = net(x)
+
+    :param str name: name of the module to be registered.
+    :param flax.nnx.Module nn_module: a pre-initialized `flax nnx` Module instance.
+    :return: a callable that takes an array as an input and returns
+        the neural network transformed output array.
+    """
+    try:
+        from flax import nnx
+    except ImportError as e:
+        raise ImportError(
+            "It looks like you want to use flax.nnx to declare "
+            "nn modules. This is an experimental feature. "
+            "You need to install the latest version of `flax` to use it. "
+            "It can be installed with `pip install git+https://github.com/google/flax.git`."
+        ) from e
+
+    graph_def, eager_params_state, eager_other_state = nnx.split(
+        nn_module, nnx.Param, nnx.Not(nnx.Param)
+    )
+
+    eager_params_state_dict = nnx.to_pure_dict(eager_params_state)
+
+    module_params = None
+    if eager_params_state:
+        module_params = numpyro.param(name + "$params")
+        if module_params is None:
+            module_params = numpyro.param(name + "$params", eager_params_state_dict)
+
+    eager_other_state_dict = nnx.to_pure_dict(eager_other_state)
+
+    mutable_holder = None
+    if eager_other_state_dict:
+        mutable_holder = numpyro_mutable(name + "$state")
+        if mutable_holder is None:
+            mutable_holder = numpyro_mutable(
+                name + "$state", {"state": eager_other_state_dict}
+            )
+
+    def apply_fn(params, *call_args, **call_kwargs):
+        params_state = eager_params_state
+        if params:
+            nnx.replace_by_pure_dict(params_state, params)
+
+        mutable_state = eager_other_state
+        if mutable_holder:
+            nnx.replace_by_pure_dict(mutable_state, mutable_holder["state"])
+
+        model = nnx.merge(graph_def, params_state, mutable_state)
+
+        model_call = model(*call_args, **call_kwargs)
+
+        if mutable_holder:
+            _, _, new_mutable_state = nnx.split(model, nnx.Param, nnx.Not(nnx.Param))
+            new_mutable_state = jax.lax.stop_gradient(new_mutable_state)
+            mutable_holder["state"] = nnx.to_pure_dict(new_mutable_state)
+
+        return model_call
+
+    return partial(apply_fn, module_params)
+
+
+def random_nnx_module(
+    name,
+    nn_module,
+    prior,
+):
+    """
+    A primitive to create a random :mod:`~flax.nnx` style neural network
+    which can be used in MCMC samplers. The parameters of the neural network
+    will be sampled from ``prior``.
+
+    :param str name: name of the module to be registered.
+    :param flax.nnx.Module nn_module: a pre-initialized `flax nnx` Module instance.
+    :param prior: a distribution, a dict of distributions, or a callable.
+        If it is a distribution, all parameters will be sampled from the same
+        distribution. If it is a dict, it maps parameter names to distributions.
+        If it is a callable, it takes a parameter name and a parameter shape as
+        inputs and returns a distribution. For example::
+
+            class Linear(nnx.Module):
+                def __init__(self, din, dout, *, rngs):
+                    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+                    self.b = nnx.Param(jnp.zeros((dout,)))
+
+                def __call__(self, x):
+                    return x @ self.w + self.b
+
+            # Eager initialization
+            linear = Linear(din=4, dout=1, rngs=nnx.Rngs(params=random.PRNGKey(0)))
+            net = random_nnx_module("net", linear, prior={"w": dist.Normal(), "b": dist.Cauchy()})
+
+        Alternatively, we can use a callable. For example, the following are equivalent::
+
+            prior=(lambda name, shape: dist.Cauchy() if name.endswith("b") else dist.Normal())
+            prior={"w": dist.Normal(), "b": dist.Cauchy()}
+
+    :return: a callable that takes an array as an input and returns
+        the neural network transformed output array.
+    """
+
+    nn = nnx_module(name, nn_module)
+
+    apply_fn = nn.func
+    params = nn.args[0]
+    other_args = nn.args[1:]
+    keywords = nn.keywords
+
+    new_params = deepcopy(params)
+
+    with numpyro.handlers.scope(prefix=name):
+        _update_params(params, new_params, prior)
+
+    return partial(apply_fn, new_params, *other_args, **keywords)
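
Editor's note: for orientation, here is a minimal end-to-end sketch of how the new nnx_module primitive is meant to be used with SVI. It is not part of this commit; the MLP class, the sizes, the synthetic data, and the AutoDelta guide are illustrative assumptions (the guide choice mirrors the tests below).

import jax
import jax.numpy as jnp
from flax import nnx

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import nnx_module
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta

# Hypothetical single-layer module for illustration; any pre-initialized
# nnx.Module works here.
class MLP(nnx.Module):
    def __init__(self, din, dout, *, rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

# Eager initialization outside the model, as the docstring prescribes.
module = MLP(din=3, dout=1, rngs=nnx.Rngs(params=jax.random.PRNGKey(0)))

def model(x, y=None):
    net = nnx_module("net", module)  # registers weights via numpyro.param
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mean = net(x).squeeze(-1)
    numpyro.sample("y", dist.Normal(mean, sigma), obs=y)

x = jax.random.normal(jax.random.PRNGKey(1), (32, 3))
y = x @ jnp.array([1.0, 2.0, 3.0])
svi = SVI(model, AutoDelta(model), numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(2), 100, x, y)  # params live under "net$params"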
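Similarly, a sketch of the fully Bayesian counterpart, random_nnx_module with NUTS. The Linear class is the one from the docstring above; the synthetic logistic-regression data and chain lengths are assumptions kept small for illustration.

import jax
import jax.numpy as jnp
from flax import nnx

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_nnx_module
from numpyro.infer import MCMC, NUTS

class Linear(nnx.Module):
    def __init__(self, din, dout, *, rngs):
        self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x):
        return x @ self.w + self.b

# Eager initialization, then priors over the registered parameters.
linear = Linear(din=3, dout=1, rngs=nnx.Rngs(params=jax.random.PRNGKey(0)))

def model(x, y=None):
    # Weights get Normal priors, biases get Cauchy priors.
    net = random_nnx_module(
        "net", linear, prior={"w": dist.Normal(), "b": dist.Cauchy()}
    )
    logits = net(x).squeeze(-1)
    numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

x = jax.random.normal(jax.random.PRNGKey(1), (100, 3))
y = dist.Bernoulli(logits=x.sum(-1)).sample(jax.random.PRNGKey(2))
mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100, progress_bar=False)
mcmc.run(jax.random.PRNGKey(3), x, y)
samples = mcmc.get_samples()  # posterior draws under "net/w" and "net/b"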

test/contrib/test_module.py

Lines changed: 166 additions & 3 deletions
@@ -2,13 +2,15 @@
 # SPDX-License-Identifier: Apache-2.0

 from copy import deepcopy
+import sys

 import numpy as np
 from numpy.testing import assert_allclose
 import pytest

 import jax
 from jax import random
+import jax.numpy as jnp

 import numpyro
 from numpyro import handlers
@@ -17,8 +19,10 @@
     _update_params,
     flax_module,
     haiku_module,
+    nnx_module,
     random_flax_module,
     random_haiku_module,
+    random_nnx_module,
 )
 import numpyro.distributions as dist
 from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
@@ -195,9 +199,9 @@ def test_random_module_mcmc(backend, init, callable_prior):
     kwargs = {}

     if callable_prior:
-        prior = (  # noqa: E731
-            lambda name, shape: dist.Cauchy() if name == bias_name else dist.Normal()
-        )
+
+        def prior(name, shape):
+            return dist.Cauchy() if name == bias_name else dist.Normal()
     else:
         prior = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}

@@ -311,3 +315,162 @@ def model():
     guide = AutoDelta(model)
     svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
     svi.run(random.PRNGKey(100), 10)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+def test_nnx_module():
+    from flax import nnx
+
+    X = np.arange(100).astype(np.float32)
+    Y = 2 * X + 2
+
+    class Linear(nnx.Module):
+        def __init__(self, din, dout, *, rngs):
+            self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+            self.bias = nnx.Param(jnp.zeros((dout,)))
+
+        def __call__(self, x):
+            w_val = self.w.value
+            bias_val = self.bias.value
+            return x @ w_val + bias_val
+
+    # Eager initialization of the Linear module outside the model
+    rng_key = random.PRNGKey(1)
+    linear_module = Linear(din=100, dout=100, rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, params_state = nnx.split(linear_module, nnx.Param)
+    params_dict = nnx.to_pure_dict(params_state)
+
+    # Verify parameters were created correctly
+    assert "w" in params_dict
+    assert "bias" in params_dict
+    assert params_dict["w"].shape == (100, 100)
+    assert params_dict["bias"].shape == (100,)
+
+    # Define a model using eager initialization
+    def nnx_model_eager(x, y):
+        # Use the pre-initialized Linear module
+        nn = nnx_module("nn", linear_module)
+        mean = nn(x)
+        numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y)
+
+    with handlers.trace() as nnx_tr, handlers.seed(rng_seed=1):
+        nnx_model_eager(X, Y)
+
+    assert "w" in nnx_tr["nn$params"]["value"]
+    assert "bias" in nnx_tr["nn$params"]["value"]
+    assert nnx_tr["nn$params"]["value"]["w"].shape == (100, 100)
+    assert nnx_tr["nn$params"]["value"]["bias"].shape == (100,)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+@pytest.mark.parametrize(
+    argnames="dropout", argvalues=[True, False], ids=["dropout", "no_dropout"]
+)
+@pytest.mark.parametrize(
+    argnames="batchnorm", argvalues=[True, False], ids=["batchnorm", "no_batchnorm"]
+)
+def test_nnx_state_dropout_smoke(dropout, batchnorm):
+    from flax import nnx
+
+    class Net(nnx.Module):
+        def __init__(self, *, rngs):
+            if batchnorm:
+                # Use feature dimension 3 to match the input shape (4, 3)
+                self.bn = nnx.BatchNorm(3, rngs=rngs)
+            if dropout:
+                # Create dropout with deterministic=True to disable dropout
+                self.dropout = nnx.Dropout(rate=0.5, deterministic=True, rngs=rngs)
+
+        def __call__(self, x, *, rngs=None):
+            if dropout:
+                # Use deterministic=True to disable dropout
+                x = self.dropout(x, deterministic=True)
+
+            if batchnorm:
+                x = self.bn(x)
+
+            return x
+
+    # Eager initialization of the Net module outside the model
+    rng_key = random.PRNGKey(0)
+    net_module = Net(rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, state = nnx.split(net_module)
+
+    def model():
+        # Use the pre-initialized module
+        nn = nnx_module("nn", net_module)
+
+        x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
+        y = nn(x)
+        numpyro.deterministic("y", y)
+
+    with handlers.trace(model) as tr, handlers.seed(rng_seed=0):
+        model()
+
+    assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"}
+    assert tr["nn$state"]["type"] == "mutable"
+
+    # test svi
+    guide = AutoDelta(model)
+    svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
+    svi.run(random.PRNGKey(100), 10)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+@pytest.mark.parametrize("callable_prior", [True, False])
+def test_random_nnx_module_mcmc(callable_prior):
+    from flax import nnx
+
+    class Linear(nnx.Module):
+        def __init__(self, din, dout, *, rngs):
+            self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+            self.b = nnx.Param(jnp.zeros((dout,)))
+
+        def __call__(self, x):
+            w_val = self.w
+            b_val = self.b
+            return x @ w_val + b_val
+
+    N, dim = 3000, 3
+    data = random.normal(random.PRNGKey(0), (N, dim))
+    true_coefs = np.arange(1.0, dim + 1.0)
+    logits = np.sum(true_coefs * data, axis=-1)
+    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
+
+    if callable_prior:
+
+        def prior(name, shape):
+            return dist.Cauchy() if name == "b" else dist.Normal()
+    else:
+        prior = {"w": dist.Normal(), "b": dist.Cauchy()}
+
+    # Create a pre-initialized module for eager initialization
+    rng_key = random.PRNGKey(0)
+    linear_module = Linear(din=dim, dout=1, rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, params_state = nnx.split(linear_module, nnx.Param)
+    params_dict = nnx.to_pure_dict(params_state)
+
+    # Verify parameters were created correctly
+    assert "w" in params_dict
+    assert "b" in params_dict
+    assert params_dict["w"].shape == (dim, 1)
+    assert params_dict["b"].shape == (1,)
+
+    def model(data, labels=None):
+        # Use the pre-initialized module with eager initialization
+        nn = random_nnx_module("nn", linear_module, prior)
+        logits = nn(data).squeeze(-1)
+        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+
+    nuts_kernel = NUTS(model)
+    mcmc = MCMC(nuts_kernel, num_warmup=2, num_samples=2, progress_bar=False)
+    mcmc.run(random.PRNGKey(0), data, labels)
+    samples = mcmc.get_samples()
+    assert "nn/b" in samples
+    assert "nn/w" in samples
