|
| 1 | +# Copyright 2024 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from typing import Any, Callable, Union |
| 15 | + |
| 16 | +import jax |
| 17 | +import numpy as np |
| 18 | +from jax import numpy as jnp |
| 19 | +from jax import random as jax_api_random |
| 20 | +from jax._src import typing |
| 21 | +from jax._src.pallas.mosaic.primitives import prng_seed |
| 22 | +from jax._src.pallas.mosaic.primitives import prng_random_bits |
| 23 | +from jax._src import prng as jax_prng |
| 24 | + |
| 25 | + |
| 26 | +Shape = jax_prng.Shape |
| 27 | +FOLD_IN_ROUNDS = 128 |
| 28 | +SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas"] |
| 29 | + |
| 30 | +def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: |
| 31 | + """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" |
| 32 | + |
| 33 | + # Only allow conversion from RBG -> Pallas keys. |
| 34 | + # There is no technical reason why we cannot support Threefry here, but |
| 35 | + # this reduces the chance of unintended behavior where the pallas PRNG |
| 36 | + # produces different random bits than Threefry. RBG has fewer guarantees |
| 37 | + # so users of RBG should be more aware of the consequences. |
| 38 | + if key._impl.name not in SUPPORTED_CONVERSION_KEYS: |
| 39 | + raise ValueError(f"Unsupported key type: {key._impl.name}" |
| 40 | + f"Supported keys are: {SUPPORTED_CONVERSION_KEYS}") |
| 41 | + |
| 42 | + key_data = jax_api_random.key_data(key) |
| 43 | + pallas_key_size = np.prod(tpu_key_impl.key_shape) |
| 44 | + if key_data.size < pallas_key_size: |
| 45 | + raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") |
| 46 | + pallas_key_data = jnp.ravel(key_data)[:pallas_key_size] |
| 47 | + pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape) |
| 48 | + return jax_api_random.wrap_key_data(pallas_key_data, impl='pallas') |
| 49 | + |
| 50 | +def _seed_func(seed: jnp.int32): |
| 51 | + seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) |
| 52 | + return (seed_data + seed).astype(jnp.uint32) |
| 53 | + |
| 54 | +def _random_bits(key: typing.Array, bit_width: int, shape: Shape): |
| 55 | + if bit_width != 32: |
| 56 | + raise NotImplementedError("Bit width must be 32") |
| 57 | + if isinstance(key.dtype, jax_prng.KeyTy): |
| 58 | + key_data = jax.random.key_data(key) |
| 59 | + else: |
| 60 | + key_data = key |
| 61 | + prng_seed(key_data[0, 0]) |
| 62 | + return prng_random_bits(shape) |
| 63 | + |
| 64 | +def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): |
| 65 | + # Roughly, we compute the new key as follows: |
| 66 | + # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] |
| 67 | + # Because the TPU generates random numbers in (8, 128) blocks at once, we |
| 68 | + # can generate that many values without additional cost which will reduce |
| 69 | + # correlation between the old and new keys. |
| 70 | + key_data = jax.random.key_data(key) |
| 71 | + |
| 72 | + prng_seed(data) |
| 73 | + data_bits = prng_random_bits( |
| 74 | + key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) |
| 75 | + prng_seed(key_data[0, 0]) |
| 76 | + key_bits = prng_random_bits( |
| 77 | + key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) |
| 78 | + |
| 79 | + mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] |
| 80 | + assert mixed.shape == key_data.shape |
| 81 | + impl: jax_prng.PRNGSpec = jax.random.key_impl(key) # type: ignore |
| 82 | + return jax.random.wrap_key_data(mixed, impl=impl) |
| 83 | + |
| 84 | +def _split(key: typing.Array, shape: Shape): |
| 85 | + del key, shape |
| 86 | + raise NotImplementedError() |
| 87 | + |
| 88 | +tpu_key_impl = jax_prng.PRNGImpl( |
| 89 | + # Use a 2D key since pallas only supports 2D tiling. |
| 90 | + key_shape=(1, 1), |
| 91 | + seed=_seed_func, |
| 92 | + split=_split, |
| 93 | + random_bits=_random_bits, |
| 94 | + fold_in=_fold_in, |
| 95 | + name="pallas", |
| 96 | + tag="pl" |
| 97 | +) |
| 98 | +jax_prng.register_prng(tpu_key_impl) |
| 99 | + |
| 100 | +# Implementation of the stateful Pallas PRNG API. |
| 101 | +# Users should set the seed using the `set_seed` function, |
| 102 | +# and call the appropriate stateful sampling functions. |
| 103 | +# The actual key impl should never be used. The impl |
| 104 | +# serves as internal boilerplate code because JAX's existing |
| 105 | +# random functions expect a key as an argument, and |
| 106 | +# the keys are only generated as part of unused arguments. |
| 107 | + |
| 108 | +def _pl_stateful_seed_func(seed: jnp.int32): |
| 109 | + del seed |
| 110 | + # Unused. Return the correct shape and dtype. |
| 111 | + return jnp.empty((), dtype=jnp.int32) |
| 112 | + |
| 113 | +def _pl_stateful_random_bits(key: typing.Array, bit_width: int, shape: Shape): |
| 114 | + del key |
| 115 | + assert bit_width == 32, "Bit width must be 32" |
| 116 | + return prng_random_bits(shape) |
| 117 | + |
| 118 | +def _pl_stateful_fold_in(key: typing.Array, data: typing.Array): |
| 119 | + del key, data |
| 120 | + raise NotImplementedError() |
| 121 | + |
| 122 | +def _pl_stateful_split(key: typing.Array, shape: Shape): |
| 123 | + del key, shape |
| 124 | + raise NotImplementedError() |
| 125 | + |
| 126 | + |
| 127 | +tpu_internal_stateful_impl = jax_prng.PRNGImpl( |
| 128 | + key_shape=(), |
| 129 | + seed=_pl_stateful_seed_func, |
| 130 | + split=_pl_stateful_split, |
| 131 | + random_bits=_pl_stateful_random_bits, |
| 132 | + fold_in=_pl_stateful_fold_in, |
| 133 | + name="_pallas_internal_stateful", |
| 134 | + tag="_pl_stateful" |
| 135 | +) |
| 136 | +jax_prng.register_prng(tpu_internal_stateful_impl) |
| 137 | + |
| 138 | +def set_seed(seed: Union[jnp.int32, jax.Array]): |
| 139 | + """Sets the seed for PRNG. |
| 140 | +
|
| 141 | + Args: |
| 142 | + seeds: An integer seed for setting the PRNG seed. |
| 143 | + """ |
| 144 | + if isinstance(seed, jax.Array): |
| 145 | + if seed.ndim != 1: |
| 146 | + raise ValueError("Seed data must be a scalar or 1D array") |
| 147 | + # TODO(justinfu): Mosaic currently only supports indexing by 0 |
| 148 | + # for scalar results when using vector.extract |
| 149 | + # After support is added, use all seed data. |
| 150 | + prng_seed(seed[0]) |
| 151 | + else: |
| 152 | + prng_seed(seed) |
| 153 | + |
| 154 | + |
| 155 | +SampleFnType = Any |
| 156 | +KeylessSampleFnType = Callable[..., jax.Array] |
| 157 | + |
| 158 | +def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType: |
| 159 | + """Converts a jax.random sampling function to a stateful version. |
| 160 | +
|
| 161 | + Args: |
| 162 | + sampler: A sampling function that consumes a key and returns |
| 163 | + random samples. |
| 164 | +
|
| 165 | + Returns: |
| 166 | + A stateful sampling function with the key argument removed. |
| 167 | + """ |
| 168 | + def new_sampler(*args, **kwargs): |
| 169 | + # Pass in a placeholder key into the sampling function. |
| 170 | + # The key is ignored by the stateful random_bits function, but all jax |
| 171 | + # sampling functions expect a key as input so we must pass one in here. |
| 172 | + placeholder_key = jax_prng.random_seed( |
| 173 | + None, impl=tpu_internal_stateful_impl) |
| 174 | + return sampler(placeholder_key, *args, **kwargs) |
| 175 | + # Remove key argument from docstring. |
| 176 | + doc_lines = filter( |
| 177 | + lambda line: 'key:' not in line, sampler.__doc__.split('\n')) |
| 178 | + new_sampler.__doc__ = '\n'.join(doc_lines) |
| 179 | + return new_sampler |
| 180 | + |
| 181 | +bits = _make_stateful_sampler(jax_api_random.bits) |
| 182 | +uniform = _make_stateful_sampler(jax_api_random.uniform) |
| 183 | +bernoulli = _make_stateful_sampler(jax_api_random.bernoulli) |
0 commit comments