Commit 7e96914

justinjfu authored and Google-ML-Automation committed
Add Pallas Philox implementation.
Implemented in the same style as the threefry kernel. Philox is roughly 2x faster than the existing JAX Threefry implementation in both runtime and compile time. PiperOrigin-RevId: 707276043
1 parent d4031e9 commit 7e96914

4 files changed: +321 -42 lines changed
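The 2x claim can be spot-checked with a rough harness along these lines (an illustrative sketch, not part of the commit; it assumes a TPU backend and that the Philox module below has been imported so its impl name is registered):

    import time
    import jax

    draw = jax.jit(lambda key: jax.random.bits(key, (8192, 8192)))
    for impl in ("threefry2x32", "pallas_philox4x32"):
      key = jax.random.key(0, impl=impl)
      start = time.perf_counter()
      draw(key).block_until_ready()  # the first call includes compile time
      print(impl, time.perf_counter() - start)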
jax/experimental/pallas/ops/tpu/random/philox.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Philox PRNG as a Pallas kernel."""
from typing import Sequence
import jax
from jax import typing
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils

Shape = Sequence[int]

BLOCK_SIZE = (256, 256)

# Philox constants. See the original paper:
# "Parallel Random Numbers: As Easy as 1, 2, 3", Salmon et al., 2011.
K_HI_32 = 0x9E3779B9
K_LO_32 = 0xBB67AE85
MUL_A = 0xCD9E8D57
MUL_B = 0xD2511F53


def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Multiplies two 32-bit values and returns the high and low 32 bits."""
  # Split operands into 16-bit halves so every partial product fits in 32 bits.
  xhi = x >> 16
  yhi = y >> 16
  xlo = x & 0xffff
  ylo = y & 0xffff

  xy_hi = xhi * yhi
  xy_lo = xlo * ylo
  cross_xy = xhi * ylo
  cross_yx = xlo * yhi
  # Carry out of the low 32 bits of the full product.
  carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
  # The low word is the product modulo 2**32, which uint32 multiplication
  # yields directly.
  return xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16), x * y
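A quick cross-check of mul32_hi_lo against exact integer arithmetic (an illustrative sketch, not part of the commit):

    import jax.numpy as jnp

    x = jnp.uint32(0xCD9E8D57)
    y = jnp.uint32(0x9E3779B9)
    hi, lo = mul32_hi_lo(x, y)
    ref = 0xCD9E8D57 * 0x9E3779B9  # exact 64-bit product as a Python int
    assert int(hi) == ref >> 32
    assert int(lo) == ref & 0xFFFFFFFF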


def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds: int = 10):
  """Philox 4x32 keyed hash function."""
  k_hi_const = jnp.array(K_HI_32, dtype=jnp.uint32)
  k_lo_const = jnp.array(K_LO_32, dtype=jnp.uint32)
  mul_a = jnp.array(MUL_A, dtype=jnp.uint32)
  mul_b = jnp.array(MUL_B, dtype=jnp.uint32)

  for i in range(rounds):
    # Compute one round of the permutation.
    new_hi0, new_lo0 = mul32_hi_lo(mul_a, hi1)
    new_hi0 = new_hi0 ^ lo0 ^ k_hi
    new_hi1, new_lo1 = mul32_hi_lo(mul_b, hi0)
    new_hi1 = new_hi1 ^ lo1 ^ k_lo
    hi0, lo0, hi1, lo1 = new_hi0, new_lo0, new_hi1, new_lo1

    # Bump the key on every round except the last.
    if i != rounds - 1:
      k_hi = k_hi + k_hi_const
      k_lo = k_lo + k_lo_const
  return hi0, lo0, hi1, lo1
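Because philox_4x32 is plain jax.numpy, it can be exercised directly outside of any Pallas kernel. A minimal sketch (shapes and key values are arbitrary):

    import jax
    import jax.numpy as jnp

    zeros = jnp.zeros((8, 128), dtype=jnp.uint32)
    counts = jax.lax.broadcasted_iota(jnp.uint32, (8, 128), 1)
    hi0, lo0, hi1, lo1 = philox_4x32(
        zeros, counts, zeros, zeros,
        jnp.uint32(0x12345678), jnp.uint32(0x9ABCDEF0))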


def philox_4x32_kernel(key,
                       shape: Shape,
                       unpadded_shape: Shape,
                       block_size: tuple[int, int],
                       offset: typing.ArrayLike = 0,
                       fuse_output: bool = True):
  """Generates random bits using the Philox keyed hash function.

  Args:
    key: A Philox key of shape (2,).
    shape: The shape of the output. Must be divisible by `block_size`.
    unpadded_shape: If `shape` is padded, the shape the output tensor would
      have without padding. This is needed for the indexing calculations
      inside the kernel. If `shape` is not padded, this should equal `shape`.
    block_size: The block size of the kernel.
    offset: An optional scalar offset added to the counters.
    fuse_output: Whether to XOR the high and low output words together into a
      single array of random bits.

  Returns:
    A tensor of random bits of shape `shape` if fuse_output=True. Otherwise,
    a tensor of shape (2, *shape) whose first channel holds the high bits and
    whose second channel holds the low bits.
  """
  shape = tuple(shape)
  if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
    raise ValueError(
        f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}")

  if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
    raise ValueError(
        f"Shape dimensions {shape[-2:]} must be divisible by {block_size}")
  grid_dims = shape[:-2] + (
      shape[-2] // block_size[-2], shape[-1] // block_size[-1],)
  offset = jnp.array(offset, dtype=jnp.uint32)
  if offset.ndim != 0:
    raise ValueError(f"Offset must be scalar, got {offset.shape}")
  offset = jnp.reshape(offset, (1,))

  def kernel(offset_ref, key_ref, out_ref):
    counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
    offset = prng_utils.compute_scalar_offset(
        counts_idx, unpadded_shape, block_shape)
    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
    counts_lo = counts_lo + offset + offset_ref[0]
    counts_lo = counts_lo.astype(jnp.uint32)
    # TODO(justinfu): Support hi bits on count.
    _zeros = jnp.zeros_like(counts_lo)
    k1 = jnp.reshape(key_ref[0, 0], (1, 1))
    k2 = jnp.reshape(key_ref[0, 1], (1, 1))
    o1, o2, _, _ = philox_4x32(_zeros, counts_lo, _zeros, _zeros, k1, k2)
    if fuse_output:
      out_bits = o1 ^ o2
      out_ref[...] = out_bits.reshape(out_ref.shape)
    else:
      out_ref[0, ...] = o1.reshape(out_ref[0].shape)
      out_ref[1, ...] = o2.reshape(out_ref[0].shape)

  key = key.reshape((1, 2))
  # Closed over by `kernel` above.
  block_shape = (1,) * (len(shape)-2) + block_size
  if fuse_output:
    out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec(block_shape, lambda *idxs: idxs)
  else:
    out = jax.ShapeDtypeStruct((2,) + shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec((2,) + block_shape, lambda *idxs: (0, *idxs))
  return pl.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
          pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
      ],
      out_specs=out_spec,
      grid=grid_dims,
      out_shape=out,
  )(offset, key)
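For reference, a direct call on an already block-aligned shape might look as follows (a hypothetical invocation; like all Pallas TPU kernels it needs a TPU backend):

    import jax.numpy as jnp

    raw_key = jnp.array([0xDEADBEEF, 0xCAFEF00D], dtype=jnp.uint32)
    bits = philox_4x32_kernel(raw_key, shape=(512, 512),
                              unpadded_shape=(512, 512),
                              block_size=(256, 256))
    assert bits.shape == (512, 512) and bits.dtype == jnp.uint32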


def philox_4x32_count(key,
                      shape: Shape,
                      offset: typing.ArrayLike = 0,
                      fuse_output: bool = True):
  """Calls philox_4x32_kernel, padding non-block-aligned shapes as needed.

  Scalar and 1D requests are promoted to 2D, and trailing dimensions that are
  not divisible by BLOCK_SIZE are rounded up for the kernel and sliced back
  down afterwards.
  """
  if len(shape) == 0:
    return philox_4x32_count(
        key, (1, 1), offset=offset, fuse_output=fuse_output)[..., 0, 0]
  elif len(shape) == 1:
    return philox_4x32_count(
        key, (1, *shape), offset=offset, fuse_output=fuse_output)[..., 0, :]

  requires_pad = (
      shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
  if requires_pad:
    padded_shape = tuple(shape[:-2]) + (
        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
    )
    padded_result = philox_4x32_kernel(
        key, padded_shape, shape,
        block_size=BLOCK_SIZE, offset=offset,
        fuse_output=fuse_output)
    return padded_result[..., :shape[-2], :shape[-1]]
  else:
    return philox_4x32_kernel(key, shape, shape,
                              block_size=BLOCK_SIZE, offset=offset,
                              fuse_output=fuse_output)
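For example, a request that is not block-aligned runs on the rounded-up grid and is sliced back down (illustrative; key values arbitrary):

    raw_key = jnp.array([0, 42], dtype=jnp.uint32)
    bits = philox_4x32_count(raw_key, (300, 70))  # padded to (512, 256) internally
    assert bits.shape == (300, 70)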


def philox_split(key, shape: Shape):
  """Splits a key into new keys, returning an array of shape (*shape, 2)."""
  bits1, bits2 = philox_4x32_count(key, shape, fuse_output=False)
  return jnp.stack([bits1, bits2], axis=bits1.ndim)


def philox_random_bits(key, bit_width: int, shape: Shape):
  if bit_width != 32:
    raise ValueError("Only 32-bit PRNG supported.")
  return philox_4x32_count(key, shape, fuse_output=True)


def philox_fold_in(key, data):
  assert data.ndim == 0
  return philox_4x32_count(key, (), offset=data, fuse_output=False)


plphilox_prng_impl = prng.PRNGImpl(
    key_shape=(2,),
    seed=prng.threefry_seed,
    split=philox_split,
    random_bits=philox_random_bits,
    fold_in=philox_fold_in,
    name="pallas_philox4x32",
    tag="pllox")

prng.register_prng(plphilox_prng_impl)
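Importing this module runs the registration above, after which the impl plugs into the standard typed-key API. A sketch, assuming the file lands at jax/experimental/pallas/ops/tpu/random/philox.py and a TPU backend is available:

    import jax
    from jax.experimental.pallas.ops.tpu.random import philox  # noqa: F401

    key = jax.random.key(42, impl="pallas_philox4x32")
    subkeys = jax.random.split(key, 3)              # dispatches to philox_split
    bits = jax.random.bits(subkeys[0], (128, 128))  # philox_random_bits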
jax/experimental/pallas/ops/tpu/random/prng_utils.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for PRNG kernels."""
from typing import Sequence
from jax import lax
import jax.numpy as jnp

Shape = Sequence[int]

round_up = lambda x, y: (x + y - 1) // y * y

def blocked_iota(block_shape: Shape,
                 total_shape: Shape):
  """Computes one block of a larger iota.

  Args:
    block_shape: The shape of the output block.
    total_shape: The shape of the full array being tiled.
  Returns:
    The iota values of the block at the origin; add a scalar block offset
    (see compute_scalar_offset) to obtain the values of any other block.
  """
  iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
  multiplier = 1
  for dim in range(len(block_shape)-1, -1, -1):
    block_mult = 1
    counts_lo = lax.broadcasted_iota(
        dtype=jnp.uint32, shape=block_shape, dimension=dim
    )
    iota_data += counts_lo * multiplier * block_mult
    multiplier *= total_shape[dim]
  return iota_data


def compute_scalar_offset(iteration_index,
                          total_size: Shape,
                          block_size: Shape):
  """Returns the row-major flat index of the first element of a block."""
  ndims = len(iteration_index)
  dim_size = 1
  total_idx = 0
  for i in range(ndims-1, -1, -1):
    dim_idx = iteration_index[i] * block_size[i]
    total_idx += dim_idx * dim_size
    dim_size *= total_size[i]
  return total_idx
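Taken together, a block's local iota plus its scalar offset reproduces the matching tile of the full iota; a small illustrative check (not part of the commit):

    import numpy as np

    full = np.arange(16, dtype=np.uint32).reshape(4, 4)
    block = blocked_iota((2, 2), (4, 4))                 # [[0, 1], [4, 5]]
    off = compute_scalar_offset((1, 1), (4, 4), (2, 2))  # flat index 10
    np.testing.assert_array_equal(np.asarray(block) + off, full[2:4, 2:4])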

jax/experimental/pallas/ops/tpu/random/threefry.py

Lines changed: 6 additions & 42 deletions
@@ -14,54 +14,17 @@
 """Implementation of the Threefry PRNG as a Pallas kernel."""
 from typing import Sequence
 import jax
-from jax import lax
 from jax._src import prng
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
+from jax.experimental.pallas.ops.tpu.random import prng_utils
 
 Shape = Sequence[int]
 
 BLOCK_SIZE = (256, 256)
 
-_round_up = lambda x, y: (x + y - 1) // y * y
-
-
-def blocked_iota(block_shape: Shape,
-                 total_shape: Shape):
-  """Computes a sub-block of a larger shaped iota.
-
-  Args:
-    block_shape: The output block shape of the iota.
-    total_shape: The total shape of the input tensor.
-  Returns:
-    Result of the blocked iota.
-  """
-  iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
-  multiplier = 1
-  for dim in range(len(block_shape)-1, -1, -1):
-    block_mult = 1
-    counts_lo = lax.broadcasted_iota(
-        dtype=jnp.uint32, shape=block_shape, dimension=dim
-    )
-    iota_data += counts_lo * multiplier * block_mult
-    multiplier *= total_shape[dim]
-  return iota_data
-
-
-def _compute_scalar_offset(iteration_index,
-                           total_size: Shape,
-                           block_size: Shape):
-  ndims = len(iteration_index)
-  dim_size = 1
-  total_idx = 0
-  for i in range(ndims-1, -1, -1):
-    dim_idx = iteration_index[i] * block_size[i]
-    total_idx += dim_idx * dim_size
-    dim_size *= total_size[i]
-  return total_idx
-
 
 def threefry_2x32_count(key,
                         shape: Shape,
@@ -97,8 +60,9 @@ def threefry_2x32_count(key,
 
   def kernel(key_ref, out_ref):
     counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
-    offset = _compute_scalar_offset(counts_idx, unpadded_shape, block_shape)
-    counts_lo = blocked_iota(block_size, unpadded_shape)
+    offset = prng_utils.compute_scalar_offset(
+        counts_idx, unpadded_shape, block_shape)
+    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
     counts_lo = counts_lo + offset
     counts_lo = counts_lo.astype(jnp.uint32)
     # TODO(justinfu): Support hi bits on count.
@@ -134,8 +98,8 @@ def plthreefry_random_bits(key, bit_width: int, shape: Shape):
       shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
   if requires_pad:
     padded_shape = tuple(shape[:-2]) + (
-        _round_up(shape[-2], BLOCK_SIZE[-2]),
-        _round_up(shape[-1], BLOCK_SIZE[-1]),
+        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
+        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
     )
     padded_result = threefry_2x32_count(
         key, padded_shape, shape, block_size=BLOCK_SIZE)
