|
| 1 | +# Copyright 2024 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Implementation of the Philox PRNG as a Pallas kernel.""" |
| 15 | +from typing import Sequence |
| 16 | +import jax |
| 17 | +from jax import typing |
| 18 | +from jax._src import prng |
| 19 | +from jax.experimental import pallas as pl |
| 20 | +from jax.experimental.pallas import tpu as pltpu |
| 21 | +import jax.numpy as jnp |
| 22 | +import numpy as np |
| 23 | +from jax.experimental.pallas.ops.tpu.random import prng_utils |
| 24 | + |
# Type alias used for array shapes throughout this module.
Shape = Sequence[int]

# Default (rows, cols) block size; passed as `block_size` to the kernel by
# `philox_4x32_count` to tile the trailing two dimensions of the output.
BLOCK_SIZE = (256, 256)

# Philox constants. See the original paper:
# "Parallel Random Numbers: As Easy as 1, 2, 3", Salmon et al., 2011.
# Increments added to the two key words between rounds.
K_HI_32 = 0x9E3779B9
K_LO_32 = 0xBB67AE85
# Multipliers applied to the counter words in every round.
MUL_A = 0xCD9E8D57
MUL_B = 0xD2511F53
| 35 | + |
| 36 | + |
def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Multiplies two unsigned 32-bit values and returns the 64-bit product.

  The product is assembled from 16-bit halves so that only 32-bit unsigned
  arithmetic (with wrap-around) is required.

  Args:
    x: Unsigned 32-bit operand(s).
    y: Unsigned 32-bit operand(s), broadcastable against `x`.

  Returns:
    A `(hi, lo)` tuple holding the upper and lower 32 bits of `x * y`.
  """
  xhi = x >> 16
  yhi = y >> 16
  xlo = x & 0xffff
  ylo = y & 0xffff

  xy_hi = xhi * yhi
  xy_lo = xlo * ylo
  cross_xy = xhi * ylo
  cross_yx = xlo * yhi
  # Everything that lands in bits 16 and above of the low word: the low halves
  # of both cross terms plus the upper half of xlo * ylo. This is < 2**18, so
  # it cannot overflow 32 bits.
  carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
  hi = xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16)
  # The low word is the low 16 bits of `carry` shifted into place, combined
  # with the low 16 bits of xlo * ylo. Returning xlo * ylo alone would drop the
  # cross-term contributions.
  lo = (carry << 16) | (xy_lo & 0xffff)
  return hi, lo
| 50 | + |
| 51 | + |
def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds = 10):
  """Applies the Philox 4x32 keyed hash to a four-word counter.

  Args:
    hi0: First counter word.
    lo0: Second counter word.
    hi1: Third counter word.
    lo1: Fourth counter word.
    k_hi: First key word.
    k_lo: Second key word.
    rounds: Number of Philox rounds to apply.

  Returns:
    The four counter words after `rounds` rounds, in the same order as the
    inputs.
  """
  key_step_hi = jnp.array(K_HI_32, dtype=jnp.uint32)
  key_step_lo = jnp.array(K_LO_32, dtype=jnp.uint32)
  multiplier_a = jnp.array(MUL_A, dtype=jnp.uint32)
  multiplier_b = jnp.array(MUL_B, dtype=jnp.uint32)

  final_round = rounds - 1
  for round_idx in range(rounds):
    # One round: multiply two of the words by the round multipliers and mix
    # the upper halves of the products with the other two words and the key.
    prod_a_hi, prod_a_lo = mul32_hi_lo(multiplier_a, hi1)
    prod_b_hi, prod_b_lo = mul32_hi_lo(multiplier_b, hi0)
    hi0, lo0, hi1, lo1 = (
        prod_a_hi ^ lo0 ^ k_hi,
        prod_a_lo,
        prod_b_hi ^ lo1 ^ k_lo,
        prod_b_lo,
    )

    # The key is bumped between rounds, so there is nothing to do after the
    # final one.
    if round_idx == final_round:
      continue
    k_hi = k_hi + key_step_hi
    k_lo = k_lo + key_step_lo
  return hi0, lo0, hi1, lo1
| 72 | + |
| 73 | + |
def philox_4x32_kernel(key,
                       shape: Shape,
                       unpadded_shape: Shape,
                       block_size: tuple[int, int],
                       offset: typing.ArrayLike = 0,
                       fuse_output: bool = True):
  """Generates random bits using the Philox keyed hash function.

  The output is tiled over its trailing two dimensions with blocks of
  `block_size`; every leading dimension becomes a grid axis with block extent
  1. Each grid step hashes a block of 32-bit counters with `philox_4x32`.

  Args:
    key: A Philox key of shape (2,).
    shape: The shape of the output. Must have at least two dimensions, and its
      trailing two dimensions must be divisible by `block_size`.
    unpadded_shape: If `shape` is padded, then this is the shape of the
      output tensor if it were not padded. This is important for indexing
      calculations within the kernel. If `shape` is not padded, then this
      should be equal to `shape`.
    block_size: The block size of the kernel.
    offset: An optional scalar offset added to the counts.
    fuse_output: Whether to fuse the output bits into a single value.

  Returns:
    A tensor of random bits of shape `shape` if fuse_output=True. Otherwise,
    this will return a tensor of shape (2, *shape) with the first channel being
    the high bits and the second channel being the low bits.

  Raises:
    ValueError: If `shape` has fewer than two dimensions, has more elements
      than a uint32 counter can address, is not divisible by `block_size`, or
      if `offset` is not a scalar.
  """
  shape = tuple(shape)
  block_size = tuple(block_size)
  if len(shape) < 2:
    raise ValueError(
        f"Shape must have at least 2 dimensions, got {shape}")

  # Use Python integers for the element count: a fixed-width product can wrap
  # around for very large shapes and silently defeat this check.
  num_elements = 1
  for dim in shape:
    num_elements *= int(dim)
  max_count = int(jnp.iinfo(jnp.uint32).max)
  if num_elements > max_count:
    raise ValueError(
        f"Shape too large: {num_elements} > {max_count}")

  if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
    raise ValueError(
        f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
  grid_dims = shape[:-2] + (
      shape[-2] // block_size[-2], shape[-1] // block_size[-1],)
  # Leading dimensions are walked one index at a time by the grid.
  block_shape = (1,) * (len(shape) - 2) + block_size

  offset = jnp.array(offset, dtype=jnp.uint32)
  if offset.ndim != 0:
    raise ValueError(f"Offset must be scalar, got {offset.shape}")
  # Passed to the kernel as a one-element array and read back as
  # `offset_ref[0]`.
  offset = jnp.reshape(offset, (1,))

  def kernel(offset_ref, key_ref, out_ref):
    counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
    # NOTE(review): presumably the counter value of this block's first element
    # within the unpadded output -- confirm against prng_utils.
    block_offset = prng_utils.compute_scalar_offset(
        counts_idx, unpadded_shape, block_shape)
    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
    counts_lo = counts_lo + block_offset + offset_ref[0]
    counts_lo = counts_lo.astype(jnp.uint32)
    # TODO(justinfu): Support hi bits on count.
    _zeros = jnp.zeros_like(counts_lo)
    k1 = jnp.reshape(key_ref[0, 0], (1, 1))
    k2 = jnp.reshape(key_ref[0, 1], (1, 1))
    o1, o2, _, _ = philox_4x32(_zeros, counts_lo, _zeros, _zeros, k1, k2)
    if fuse_output:
      out_bits = o1 ^ o2
      out_ref[...] = out_bits.reshape(out_ref.shape)
    else:
      # Take the per-channel shape from the ref's static shape rather than
      # loading a channel just to inspect it.
      channel_shape = out_ref.shape[1:]
      out_ref[0, ...] = o1.reshape(channel_shape)
      out_ref[1, ...] = o2.reshape(channel_shape)

  key = key.reshape((1, 2))
  if fuse_output:
    out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec(block_shape, lambda *idxs: idxs)
  else:
    # Both channels of a block are produced by the same grid step.
    out = jax.ShapeDtypeStruct((2,) + shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec((2,) + block_shape, lambda *idxs: (0, *idxs))
  return pl.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      ],
      out_specs=out_spec,
      grid=grid_dims,
      out_shape=out,
  )(offset, key)
| 150 | + |
| 151 | + |
def philox_4x32_count(key,
                      shape: Shape,
                      offset: typing.ArrayLike = 0,
                      fuse_output: bool = True):
  """Runs `philox_4x32_kernel` for an arbitrary `shape`, padding as needed.

  Shapes of rank 0 or 1 are promoted to rank 2 and the extra dimensions are
  indexed away afterwards. Shapes whose trailing two dimensions are not
  multiples of `BLOCK_SIZE` are rounded up for the kernel launch and the
  result is sliced back to the requested extent.

  Args:
    key: A Philox key of shape (2,).
    shape: The requested output shape.
    offset: An optional offset forwarded to the kernel.
    fuse_output: Forwarded to the kernel; see `philox_4x32_kernel`.

  Returns:
    Random bits of shape `shape` when `fuse_output` is True, otherwise of
    shape `(2, *shape)`.
  """
  rank = len(shape)
  if rank == 0:
    promoted = philox_4x32_count(
        key, (1, 1), offset=offset, fuse_output=fuse_output)
    return promoted[..., 0, 0]
  if rank == 1:
    promoted = philox_4x32_count(
        key, (1, *shape), offset=offset, fuse_output=fuse_output)
    return promoted[..., 0, :]

  rows, cols = shape[-2], shape[-1]
  block_rows, block_cols = BLOCK_SIZE[-2], BLOCK_SIZE[-1]
  is_block_aligned = (rows % block_rows == 0) and (cols % block_cols == 0)
  if is_block_aligned:
    return philox_4x32_kernel(key, shape, shape,
                              block_size=BLOCK_SIZE, offset=offset,
                              fuse_output=fuse_output)

  # Launch on a shape rounded up to whole blocks, then drop the padding.
  padded_shape = (
      *shape[:-2],
      prng_utils.round_up(rows, block_rows),
      prng_utils.round_up(cols, block_cols),
  )
  padded_bits = philox_4x32_kernel(
      key, padded_shape, shape,
      block_size=BLOCK_SIZE, offset=offset,
      fuse_output=fuse_output)
  return padded_bits[..., :rows, :cols]
| 180 | + |
| 181 | + |
def philox_split(key, shape: Shape):
  """Derives new key data of shape `(*shape, 2)` from `key`.

  The two unfused output channels of `philox_4x32_count` become the two words
  of each new key, stacked along a new trailing axis.
  """
  word0, word1 = philox_4x32_count(key, shape, fuse_output=False)
  return jnp.stack((word0, word1), axis=-1)
| 186 | + |
| 187 | + |
def philox_random_bits(key, bit_width: int, shape: Shape):
  """Returns random 32-bit values of shape `shape` generated from `key`.

  Raises:
    ValueError: If `bit_width` is anything other than 32.
  """
  if bit_width == 32:
    return philox_4x32_count(key, shape, fuse_output=True)
  raise ValueError("Only 32-bit PRNG supported.")
| 192 | + |
| 193 | + |
def philox_fold_in(key, data):
  """Folds the scalar `data` into `key`, returning new key data of shape (2,).

  `data` is used as the counter offset of a single scalar Philox evaluation;
  the two unfused output words form the new key.
  """
  data_rank = data.ndim
  assert data_rank == 0
  scalar_shape = ()
  return philox_4x32_count(
      key, scalar_shape, offset=data, fuse_output=False)
| 197 | + |
| 198 | + |
# PRNG implementation built from the Philox functions defined in this module.
# Key data has shape (2,); seeding reuses `prng.threefry_seed`.
plphilox_prng_impl = prng.PRNGImpl(
    key_shape=(2,),
    seed=prng.threefry_seed,
    split=philox_split,
    random_bits=philox_random_bits,
    fold_in=philox_fold_in,
    name="pallas_philox4x32",
    tag="pllox")

# Runs at import time: registers the implementation via `prng.register_prng`.
prng.register_prng(plphilox_prng_impl)
0 commit comments