# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
+import sys

import numpy as np
from numpy.testing import assert_allclose
import pytest

import jax
from jax import random
+import jax.numpy as jnp

import numpyro
from numpyro import handlers
from numpyro.contrib.module import (
    ParamShape,
    _update_params,
    flax_module,
    haiku_module,
+    nnx_module,
    random_flax_module,
    random_haiku_module,
+    random_nnx_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
@@ -195,9 +199,9 @@ def test_random_module_mcmc(backend, init, callable_prior):
    kwargs = {}

    if callable_prior:
-        prior = (  # noqa: E731
-            lambda name, shape: dist.Cauchy() if name == bias_name else dist.Normal()
-        )
+
+        def prior(name, shape):
+            return dist.Cauchy() if name == bias_name else dist.Normal()
    else:
        prior = {bias_name: dist.Cauchy(), weight_name: dist.Normal()}

@@ -311,3 +315,162 @@ def model():
    guide = AutoDelta(model)
    svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
    svi.run(random.PRNGKey(100), 10)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+def test_nnx_module():
+    from flax import nnx
+
+    X = np.arange(100).astype(np.float32)
+    Y = 2 * X + 2
+
+    class Linear(nnx.Module):
+        def __init__(self, din, dout, *, rngs):
+            self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+            self.bias = nnx.Param(jnp.zeros((dout,)))
+
+        def __call__(self, x):
+            w_val = self.w.value
+            bias_val = self.bias.value
+            return x @ w_val + bias_val
+
+    # Eager initialization of the Linear module outside the model
+    rng_key = random.PRNGKey(1)
+    linear_module = Linear(din=100, dout=100, rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, params_state = nnx.split(linear_module, nnx.Param)
+    params_dict = nnx.to_pure_dict(params_state)
+
+    # Verify parameters were created correctly
+    assert "w" in params_dict
+    assert "bias" in params_dict
+    assert params_dict["w"].shape == (100, 100)
+    assert params_dict["bias"].shape == (100,)
+
+    # Define a model using eager initialization
+    def nnx_model_eager(x, y):
+        # Use the pre-initialized Linear module
+        nn = nnx_module("nn", linear_module)
+        mean = nn(x)
+        numpyro.sample("y", numpyro.distributions.Normal(mean, 0.1), obs=y)
+
+    with handlers.trace() as nnx_tr, handlers.seed(rng_seed=1):
+        nnx_model_eager(X, Y)
+
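+    # nnx_module exposes the module's parameters in the trace as a single "nn$params" site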
+    assert "w" in nnx_tr["nn$params"]["value"]
+    assert "bias" in nnx_tr["nn$params"]["value"]
+    assert nnx_tr["nn$params"]["value"]["w"].shape == (100, 100)
+    assert nnx_tr["nn$params"]["value"]["bias"].shape == (100,)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+@pytest.mark.parametrize(
+    argnames="dropout", argvalues=[True, False], ids=["dropout", "no_dropout"]
+)
+@pytest.mark.parametrize(
+    argnames="batchnorm", argvalues=[True, False], ids=["batchnorm", "no_batchnorm"]
+)
+def test_nnx_state_dropout_smoke(dropout, batchnorm):
+    from flax import nnx
+
+    class Net(nnx.Module):
+        def __init__(self, *, rngs):
+            if batchnorm:
+                # Use feature dimension 3 to match the input shape (4, 3)
+                self.bn = nnx.BatchNorm(3, rngs=rngs)
+            if dropout:
+                # Create dropout with deterministic=True to disable dropout
+                self.dropout = nnx.Dropout(rate=0.5, deterministic=True, rngs=rngs)
+
+        def __call__(self, x, *, rngs=None):
+            if dropout:
+                # Use deterministic=True to disable dropout
+                x = self.dropout(x, deterministic=True)
+
+            if batchnorm:
+                x = self.bn(x)
+
+            return x
+
+    # Eager initialization of the Net module outside the model
+    rng_key = random.PRNGKey(0)
+    net_module = Net(rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, state = nnx.split(net_module)
+
+    def model():
+        # Use the pre-initialized module
+        nn = nnx_module("nn", net_module)
+
+        x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2))
+        y = nn(x)
+        numpyro.deterministic("y", y)
+
+    with handlers.trace(model) as tr, handlers.seed(rng_seed=0):
+        model()
+
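+    # The module's mutable state (e.g. BatchNorm running statistics) is tracked via the "nn$state" site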
+    assert set(tr.keys()) == {"nn$params", "nn$state", "x", "y"}
+    assert tr["nn$state"]["type"] == "mutable"
+
+    # test svi
+    guide = AutoDelta(model)
+    svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
+    svi.run(random.PRNGKey(100), 10)
+
+
+@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
+@pytest.mark.parametrize("callable_prior", [True, False])
+def test_random_nnx_module_mcmc(callable_prior):
+    from flax import nnx
+
+    class Linear(nnx.Module):
+        def __init__(self, din, dout, *, rngs):
+            self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+            self.b = nnx.Param(jnp.zeros((dout,)))
+
+        def __call__(self, x):
+            w_val = self.w
+            b_val = self.b
+            return x @ w_val + b_val
+
+    N, dim = 3000, 3
+    data = random.normal(random.PRNGKey(0), (N, dim))
+    true_coefs = np.arange(1.0, dim + 1.0)
+    logits = np.sum(true_coefs * data, axis=-1)
+    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
+
+    if callable_prior:
+
+        def prior(name, shape):
+            return dist.Cauchy() if name == "b" else dist.Normal()
+    else:
+        prior = {"w": dist.Normal(), "b": dist.Cauchy()}
+
+    # Create a pre-initialized module for eager initialization
+    rng_key = random.PRNGKey(0)
+    linear_module = Linear(din=dim, dout=1, rngs=nnx.Rngs(params=rng_key))
+
+    # Extract parameters and state for inspection
+    _, params_state = nnx.split(linear_module, nnx.Param)
+    params_dict = nnx.to_pure_dict(params_state)
+
+    # Verify parameters were created correctly
+    assert "w" in params_dict
+    assert "b" in params_dict
+    assert params_dict["w"].shape == (dim, 1)
+    assert params_dict["b"].shape == (1,)
+
+    def model(data, labels=None):
+        # Use the pre-initialized module with eager initialization
+        nn = random_nnx_module("nn", linear_module, prior)
+        logits = nn(data).squeeze(-1)
+        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+
+    nuts_kernel = NUTS(model)
+    mcmc = MCMC(nuts_kernel, num_warmup=2, num_samples=2, progress_bar=False)
+    mcmc.run(random.PRNGKey(0), data, labels)
+    samples = mcmc.get_samples()
| 475 | + assert "nn/b" in samples |
| 476 | + assert "nn/w" in samples |