diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb index e6a8e4f82..8acc37e2a 100644 --- a/docs_nnx/guides/randomness.ipynb +++ b/docs_nnx/guides/randomness.ipynb @@ -15,6 +15,7 @@ "metadata": {}, "outputs": [], "source": [ + "import jax\n", "from flax import nnx\n", "\n", "class Model(nnx.Module):\n", @@ -70,7 +71,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -82,7 +83,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -93,7 +94,7 @@ } ], "source": [ - "rngs = nnx.Rngs(params=0, dropout=random.key(1))\n", + "rngs = nnx.Rngs(params=0, dropout=jax.random.key(1))\n", "nnx.display(rngs)" ] }, @@ -114,7 +115,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -126,7 +127,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -249,7 +250,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -261,7 +262,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -321,14 +322,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Forking random state" + "## Split random state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into separate pieces for each layer. This can be accomplished with the `fork` method, as shown below." + "Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to split the random state into separate pieces for each layer. This can be accomplished with the `split` method, as shown below." ] }, { @@ -410,8 +411,8 @@ ], "source": [ "dropout_rngs = nnx.Rngs(1)\n", - "forked_rngs = dropout_rngs.fork(split=5)\n", - "(dropout_rngs, forked_rngs)" + "split_rngs = dropout_rngs.split(5)\n", + "(dropout_rngs, split_rngs)" ] }, { @@ -431,14 +432,14 @@ } ], "source": [ - "model_forward(model, jnp.ones((5, 20)), forked_rngs).shape" + "model_forward(model, jnp.ones((5, 20)), split_rngs).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The output of `rng.fork` is another `Rng` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`." + "The output of `rng.split` is another `Rng` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `split_rngs` they have shape `(5,)`." 
] }, { @@ -457,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -501,13 +502,13 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -519,7 +520,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -560,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -581,14 +582,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Forking implicit random state\n", + "## Splitting implicit random state\n", "\n", - "We saw above how to use `rng.fork` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. The decorator `nnx.fork_rngs` allows this for implicit random state. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above." + "We saw above how to use `rng.split` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. We can do the same for implicit random state using `nnx.split_rngs`. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above." 
] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -600,12 +601,14 @@ } ], "source": [ - "rng_axes = nnx.StateAxes({'dropout': 0, ...: None})\n", - "\n", - "@nnx.fork_rngs(split={'dropout': 5})\n", - "@nnx.vmap(in_axes=(rng_axes, None), out_axes=0)\n", + "@nnx.split_rngs(splits=5, only='dropout')\n", "def sample_from_model(model, x):\n", - " return model(x)\n", + " graphdef, dropout, others = nnx.split(model, 'dropout', ...)\n", + " @nnx.vmap(in_axes=(0, None), out_axes=0)\n", + " def f(dropout, others):\n", + " model = nnx.merge(graphdef, dropout, others)\n", + " return model(x)\n", + " return f(dropout, others)\n", "\n", "print(sample_from_model(model, x).shape)" ] }, @@ -614,9 +617,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Here `sample_from_model` is modified by two decorators:\n", - "- The function we get from the `nnx.vmap` decorator expects that the random state of the `model` argument has already been split into 5 pieces. It runs the model once for each random key.\n", - "- The function we get from the `nnx.fork_rngs` decorator splits the random state of its `model` argument into five pieces before passing it on to the inner function." + "The `nnx.split_rngs` decorator makes `sample_from_model` split the RNG state of the `model` argument into 5 pieces before the function body runs. \n", "Within the function body, we need to separate the dropout RNG state from the rest of the implicit state so that we can `vmap` over its axis alone. To do this, we `nnx.split` the model, using \"dropout\" as our filter. Finally, we `vmap` the inner function over these split components." 
] }, { @@ -642,7 +644,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -658,7 +660,7 @@ " def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:\n", " h = self.drop(h) # Recurrent dropout.\n", " y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))\n", - " self.count.value += 1\n", + " self.count[...] = self.count[...] + 1\n", " return y, y\n", "\n", " def initial_state(self, batch_size: int):\n", @@ -678,7 +680,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -686,30 +688,30 @@ "output_type": "stream", "text": [ "y.shape = (4, 20, 16)\n", - "cell.count.value = Array(20, dtype=uint32)\n" + "cell.count[...] = Array(20, dtype=uint32)\n" ] } ], "source": [ "@nnx.jit\n", "def rnn_forward(cell: RNNCell, x: jax.Array):\n", - " h = cell.initial_state(batch_size=x.shape[0])\n", - "\n", - " # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.\n", - " state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})\n", - " @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n", - " def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:\n", - " h, y = cell(h, x)\n", - " return h, y\n", - "\n", - " h, y = unroll(cell, h, x)\n", - " return y\n", + " h = cell.initial_state(batch_size=x.shape[0])\n", + " graphdef, rngs, state = nnx.split(cell, 'recurrent_dropout', ...)\n", + " # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.\n", + " @nnx.scan(in_axes=(None, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n", + " def unroll(rngs, mutable, x) -> tuple[jax.Array, jax.Array]:\n", + " state, h = mutable\n", + " cell = nnx.merge(graphdef, rngs, state)\n", + " h, y = cell(h, x)\n", + " return (state, h), y\n", + " _, y = unroll(rngs, (state, h), x)\n", + " return y\n", "\n", "x = jnp.ones((4, 20, 8))\n", "y = rnn_forward(cell, x)\n", "\n", 
"print(f'{y.shape = }')\n", - "print(f'{cell.count.value = }')" + "print(f'{cell.count[...] = }')" ] } ], @@ -727,7 +729,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md index c7955c5ba..8ac4769db 100644 --- a/docs_nnx/guides/randomness.md +++ b/docs_nnx/guides/randomness.md @@ -13,6 +13,7 @@ jupytext: Flax NNX uses the stateful `nnx.Rngs` class to simplify Jax's handling of random states. For example, the code below uses a `nnx.Rngs` object to define a simple linear model with dropout: ```{code-cell} ipython3 +import jax from flax import nnx class Model(nnx.Module): @@ -56,7 +57,7 @@ To create an `nnx.Rngs` object you can simply pass an integer seed or `jax.rando Here's an example: ```{code-cell} ipython3 -rngs = nnx.Rngs(params=0, dropout=random.key(1)) +rngs = nnx.Rngs(params=0, dropout=jax.random.key(1)) nnx.display(rngs) ``` @@ -149,11 +150,11 @@ z1 = rngs.normal((2, 3)) # generates key from rngs.default z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params ``` -## Forking random state +## Splitting random state +++ -Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into separate pieces for each layer. This can be accomplished with the `fork` method, as shown below. +Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to split the random state into a separate piece for each element of the batch. This can be accomplished with the `split` method, as shown below. 
```{code-cell} ipython3 class Model(nnx.Module): @@ -177,15 +178,15 @@ def model_forward(model, x, rngs): ``` ```{code-cell} ipython3 dropout_rngs = nnx.Rngs(1) -forked_rngs = dropout_rngs.fork(split=5) -(dropout_rngs, forked_rngs) +split_rngs = dropout_rngs.split(5) +(dropout_rngs, split_rngs) ``` ```{code-cell} ipython3 -model_forward(model, jnp.ones((5, 20)), forked_rngs).shape +model_forward(model, jnp.ones((5, 20)), split_rngs).shape ``` -The output of `rng.fork` is another `Rng` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`. +The output of `Rngs.split` is another `Rngs` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `split_rngs` they have shape `(5,)`. +++ @@ -256,24 +257,25 @@ assert not jnp.allclose(y1, y2) # different assert jnp.allclose(y1, y3) # same ``` -## Forking implicit random state +## Splitting implicit random state -We saw above how to use `rng.fork` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. The decorator `nnx.fork_rngs` allows this for implicit random state. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above. +We saw above how to use `Rngs.split` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. We can do the same for implicit random state using `nnx.split_rngs`. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above. 
```{code-cell} ipython3 -rng_axes = nnx.StateAxes({'dropout': 0, ...: None}) - -@nnx.fork_rngs(split={'dropout': 5}) -@nnx.vmap(in_axes=(rng_axes, None), out_axes=0) +@nnx.split_rngs(splits=5, only='dropout') def sample_from_model(model, x): - return model(x) + graphdef, dropout, others = nnx.split(model, 'dropout', ...) + @nnx.vmap(in_axes=(0, None), out_axes=0) + def f(dropout, others): + model = nnx.merge(graphdef, dropout, others) + return model(x) + return f(dropout, others) print(sample_from_model(model, x).shape) ``` -Here `sample_from_model` is modified by two decorators: -- The function we get from the `nnx.vmap` decorator expects that the random state of the `model` argument has already been split into 5 pieces. It runs the model once for each random key. -- The function we get from the `nnx.fork_rngs` decorator splits the random state of its `model` argument into five pieces before passing it on to the inner function. +The `nnx.split_rngs` decorator makes `sample_from_model` split the RNG state of the `model` argument into 5 pieces before the function body runs. +Within the function body, we need to separate the dropout RNG state from the rest of the implicit state so that we can `vmap` over its axis alone. To do this, we `nnx.split` the model, using "dropout" as our filter. Finally, we `vmap` the inner function over these split components. +++ @@ -303,7 +305,7 @@ class RNNCell(nnx.Module): def __call__(self, h, x) -> tuple[jax.Array, jax.Array]: h = self.drop(h) # Recurrent dropout. y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1))) - self.count.value += 1 + self.count[...] = self.count[...] 
+ 1 return y, y def initial_state(self, batch_size: int): @@ -319,21 +321,21 @@ Next, use `nnx.scan` over an `unroll` function to implement the `rnn_forward` op ```{code-cell} ipython3 @nnx.jit def rnn_forward(cell: RNNCell, x: jax.Array): - h = cell.initial_state(batch_size=x.shape[0]) - - # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step. - state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry}) - @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) - def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]: - h, y = cell(h, x) - return h, y - - h, y = unroll(cell, h, x) - return y + h = cell.initial_state(batch_size=x.shape[0]) + graphdef, rngs, state = nnx.split(cell, 'recurrent_dropout', ...) + # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step. + @nnx.scan(in_axes=(None, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) + def unroll(rngs, mutable, x) -> tuple[jax.Array, jax.Array]: + state, h = mutable + cell = nnx.merge(graphdef, rngs, state) + h, y = cell(h, x) + return (state, h), y + _, y = unroll(rngs, (state, h), x) + return y x = jnp.ones((4, 20, 8)) y = rnn_forward(cell, x) print(f'{y.shape = }') -print(f'{cell.count.value = }') +print(f'{cell.count[...] = }') ``` diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 6e7d0c9db..960f6299a 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -438,7 +438,7 @@ def items(self): def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]] | int | tuple[int, ...]): """ - Splits the keys of the newly created ``Rngs`` object. + Splits the keys of a ``Rngs`` object. Example:: >>> from flax import nnx
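Reviewer note on the renamed `split` workflow documented above: under the hood it is ordinary JAX key splitting. A minimal JAX-only sketch of the per-batch-element masking idea the guide describes (no Flax required; `dropout_mask` and the other names below are illustrative, not Flax API):

```python
import jax

# One base key, split into 5 independent keys -- one per batch element,
# mirroring `dropout_rngs.split(5)` in the guide.
base_key = jax.random.key(1)
keys = jax.random.split(base_key, 5)

def dropout_mask(key, shape, rate=0.5):
    # Keep-mask sampled independently from each key.
    return jax.random.bernoulli(key, 1.0 - rate, shape)

# vmap over the leading key axis: each batch row gets its own mask.
masks = jax.vmap(lambda k: dropout_mask(k, (20,)))(keys)
print(masks.shape)  # (5, 20)
```

Each of the 5 keys yields an independent mask, which is what splitting the `dropout` stream and then applying `nnx.vmap` achieves for the model's implicit dropout state.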