Skip to content

Commit 9439f63

Browse files
justinjfujax authors
authored andcommitted
[Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
1 parent 3d4ee0d commit 9439f63

File tree

8 files changed

+433
-5
lines changed

8 files changed

+433
-5
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ pytype_strict_library(
631631
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
632632
"//jax/_src/pallas/mosaic:pipeline",
633633
"//jax/_src/pallas/mosaic:primitives",
634+
"//jax/_src/pallas/mosaic:random",
634635
],
635636
)
636637

jax/_src/pallas/mosaic/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,13 @@ py_library(
9797
"//jax/_src/pallas",
9898
],
9999
)
100+
101+
py_library(
102+
name = "random",
103+
srcs = ["random.py"],
104+
deps = [
105+
":primitives",
106+
"//jax",
107+
"//jax:typing",
108+
] + py_deps("numpy"),
109+
)

jax/_src/pallas/mosaic/lowering.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src import linear_util as lu
3232
from jax._src import mesh as mesh_lib
3333
from jax._src import pjit
34+
from jax._src import prng
3435
from jax._src import source_info_util
3536
from jax._src import state
3637
from jax._src.interpreters import mlir
@@ -76,6 +77,12 @@
7677
map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin
7778
zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin
7879

80+
UNSIGNED_TO_SIGNED = {
81+
np.dtype('uint8'): np.dtype('int8'),
82+
np.dtype('uint16'): np.dtype('int16'),
83+
np.dtype('uint32'): np.dtype('int32'),
84+
np.dtype('uint64'): np.dtype('int64'),
85+
}
7986

8087
@dataclasses.dataclass
8188
class MeshContext:
@@ -123,7 +130,13 @@ def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
123130
return ir.Type.parse("!tpu.semaphore")
124131
else:
125132
raise NotImplementedError
126-
return mlir.dtype_to_ir_type(dtype)
133+
# TODO(justinfu): Remove after mosaic supports unsigned types.
134+
# This conversion makes mosaic interpret all unsigned types as signed types.
135+
type = mlir.dtype_to_ir_type(dtype)
136+
if isinstance(type, ir.IntegerType):
137+
return ir.IntegerType.get_signless(type.width)
138+
else:
139+
return type
127140

128141
def aval_to_ir_type(aval, shape=None, memory_space: TPUMemorySpace | None = None):
129142
if isinstance(aval, tpu_core.AbstractSemaphore):
@@ -1214,14 +1227,14 @@ def _convert_helper(x, *, to_dtype):
12141227
if jnp.issubdtype(from_dtype, jnp.dtype("bool")):
12151228
x = x.astype(jnp.int32)
12161229
return _convert_helper(x, to_dtype=to_dtype)
1217-
if jnp.issubdtype(from_dtype, jnp.integer):
1230+
if jnp.issubdtype(from_dtype, jnp.signedinteger):
12181231
if from_dtype.itemsize < 4:
12191232
x = x.astype(jnp.int32)
12201233
if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
12211234
x = x.astype(jnp.float32)
12221235
return x.astype(to_dtype)
12231236
if jnp.issubdtype(from_dtype, jnp.floating):
1224-
if jnp.issubdtype(to_dtype, jnp.integer):
1237+
if jnp.issubdtype(to_dtype, jnp.signedinteger):
12251238
if from_dtype.itemsize < 4:
12261239
x = x.astype(jnp.float32)
12271240
if to_dtype.itemsize < 4:
@@ -1242,6 +1255,11 @@ def _convert_element_type_lowering_rule(
12421255
out_aval = ctx.avals_out[0]
12431256
old_dtype = ctx.avals_in[0].dtype
12441257
out_type = aval_to_ir_type(out_aval)
1258+
1259+
# TODO(justinfu): Remove after mosaic supports unsigned types.
1260+
# This conversion makes mosaic interpret all unsigned types as signed types.
1261+
if np.issubdtype(new_dtype, jnp.unsignedinteger):
1262+
new_dtype = UNSIGNED_TO_SIGNED[new_dtype]
12451263
if old_dtype == new_dtype:
12461264
return x
12471265
if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
@@ -2158,9 +2176,16 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
21582176
(out_aval,) = ctx.avals_out
21592177
return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result
21602178

2161-
21622179
lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule
21632180

2181+
def _bitcast_convert_type_lowering_rule(
2182+
ctx: LoweringRuleContext, x, *, new_dtype):
2183+
(in_aval, ) = ctx.avals_in
2184+
(out_aval,) = ctx.avals_out
2185+
if in_aval.dtype.itemsize != new_dtype.itemsize:
2186+
raise NotImplementedError("Changing bitwidths not supported.")
2187+
return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result
2188+
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule
21642189

21652190
def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
21662191
if isinstance(aval, pl_core.AbstractMemoryRef):
@@ -2380,3 +2405,41 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape):
23802405
out_type = aval_to_ir_type(out_aval)
23812406
return tpu.PRNGRandomBitsOp(out_type).result
23822407
lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule
2408+
2409+
2410+
def random_seed_lowering(ctx, seeds, *, impl):
2411+
seed_lowering = lower_fun(
2412+
impl.seed, multiple_results=False)
2413+
return seed_lowering(ctx, seeds)
2414+
lowering_rules[prng.random_seed_p] = random_seed_lowering
2415+
2416+
2417+
def random_bits_lowering(ctx, keys, *, bit_width, shape):
2418+
assert bit_width == 32, "Only 32-bit PRNG supported."
2419+
aval, = ctx.avals_in
2420+
impl = aval.dtype._impl
2421+
bits_lowering = lower_fun(
2422+
impl.random_bits, multiple_results=False)
2423+
return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape)
2424+
lowering_rules[prng.random_bits_p] = random_bits_lowering
2425+
2426+
2427+
def random_fold_in_lowering(ctx, keys, msgs):
2428+
keys_aval, _ = ctx.avals_in
2429+
impl = keys_aval.dtype._impl
2430+
fold_in_lowering = lower_fun(
2431+
impl.fold_in, multiple_results=False)
2432+
return fold_in_lowering(ctx, keys, msgs)
2433+
lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering
2434+
2435+
2436+
def random_unwrap_lowering(ctx, key):
2437+
del ctx
2438+
return key
2439+
lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering
2440+
2441+
2442+
def random_wrap_lowering(ctx, key_data, *, impl):
2443+
del ctx, impl
2444+
return key_data
2445+
lowering_rules[prng.random_wrap_p] = random_wrap_lowering

jax/_src/pallas/mosaic/primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def _(*_):
604604
return []
605605

606606

607-
def prng_seed(*seeds: tuple[int | jax.Array, ...]) -> None:
607+
def prng_seed(*seeds: int | jax.Array) -> None:
608608
"""Sets the seed for PRNG.
609609
610610
Args:

jax/_src/pallas/mosaic/random.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Callable, Union
15+
16+
import jax
17+
import numpy as np
18+
from jax import numpy as jnp
19+
from jax import random as jax_api_random
20+
from jax._src import typing
21+
from jax._src.pallas.mosaic.primitives import prng_seed
22+
from jax._src.pallas.mosaic.primitives import prng_random_bits
23+
from jax._src import prng as jax_prng
24+
25+
26+
Shape = jax_prng.Shape
27+
FOLD_IN_ROUNDS = 128
28+
SUPPORTED_CONVERSION_KEYS = ["rbg", "unsafe_rbg", "pallas"]
29+
30+
def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray:
31+
"""Helper function for converting non-Pallas PRNG keys into Pallas keys."""
32+
33+
# Only allow conversion from RBG -> Pallas keys.
34+
# There is no technical reason why we cannot support Threefry here, but
35+
# this reduces the chance of unintended behavior where the pallas PRNG
36+
# produces different random bits than Threefry. RBG has fewer guarantees
37+
# so users of RBG should be more aware of the consequences.
38+
if key._impl.name not in SUPPORTED_CONVERSION_KEYS:
39+
raise ValueError(f"Unsupported key type: {key._impl.name}"
40+
f"Supported keys are: {SUPPORTED_CONVERSION_KEYS}")
41+
42+
key_data = jax_api_random.key_data(key)
43+
pallas_key_size = np.prod(tpu_key_impl.key_shape)
44+
if key_data.size < pallas_key_size:
45+
raise ValueError(f"Key data must be at least {pallas_key_size} bytes.")
46+
pallas_key_data = jnp.ravel(key_data)[:pallas_key_size]
47+
pallas_key_data = jnp.reshape(pallas_key_data, tpu_key_impl.key_shape)
48+
return jax_api_random.wrap_key_data(pallas_key_data, impl='pallas')
49+
50+
def _seed_func(seed: jnp.int32):
51+
seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32)
52+
return (seed_data + seed).astype(jnp.uint32)
53+
54+
def _random_bits(key: typing.Array, bit_width: int, shape: Shape):
55+
if bit_width != 32:
56+
raise NotImplementedError("Bit width must be 32")
57+
if isinstance(key.dtype, jax_prng.KeyTy):
58+
key_data = jax.random.key_data(key)
59+
else:
60+
key_data = key
61+
prng_seed(key_data[0, 0])
62+
return prng_random_bits(shape)
63+
64+
def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array):
65+
# Roughly, we compute the new key as follows:
66+
# new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127]
67+
# Because the TPU generates random numbers in (8, 128) blocks at once, we
68+
# can generate that many values without additional cost which will reduce
69+
# correlation between the old and new keys.
70+
key_data = jax.random.key_data(key)
71+
72+
prng_seed(data)
73+
data_bits = prng_random_bits(
74+
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
75+
prng_seed(key_data[0, 0])
76+
key_bits = prng_random_bits(
77+
key_data.shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32)
78+
79+
mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1]
80+
assert mixed.shape == key_data.shape
81+
impl: jax_prng.PRNGSpec = jax.random.key_impl(key) # type: ignore
82+
return jax.random.wrap_key_data(mixed, impl=impl)
83+
84+
def _split(key: typing.Array, shape: Shape):
85+
del key, shape
86+
raise NotImplementedError()
87+
88+
tpu_key_impl = jax_prng.PRNGImpl(
89+
# Use a 2D key since pallas only supports 2D tiling.
90+
key_shape=(1, 1),
91+
seed=_seed_func,
92+
split=_split,
93+
random_bits=_random_bits,
94+
fold_in=_fold_in,
95+
name="pallas",
96+
tag="pl"
97+
)
98+
jax_prng.register_prng(tpu_key_impl)
99+
100+
# Implementation of the stateful Pallas PRNG API.
101+
# Users should set the seed using the `set_seed` function,
102+
# and call the appropriate stateful sampling functions.
103+
# The actual key impl should never be used. The impl
104+
# serves as internal boilerplate code because JAX's existing
105+
# random functions expect a key as an argument, and
106+
# the keys are only generated as part of unused arguments.
107+
108+
def _pl_stateful_seed_func(seed: jnp.int32):
109+
del seed
110+
# Unused. Return the correct shape and dtype.
111+
return jnp.empty((), dtype=jnp.int32)
112+
113+
def _pl_stateful_random_bits(key: typing.Array, bit_width: int, shape: Shape):
114+
del key
115+
assert bit_width == 32, "Bit width must be 32"
116+
return prng_random_bits(shape)
117+
118+
def _pl_stateful_fold_in(key: typing.Array, data: typing.Array):
119+
del key, data
120+
raise NotImplementedError()
121+
122+
def _pl_stateful_split(key: typing.Array, shape: Shape):
123+
del key, shape
124+
raise NotImplementedError()
125+
126+
127+
tpu_internal_stateful_impl = jax_prng.PRNGImpl(
128+
key_shape=(),
129+
seed=_pl_stateful_seed_func,
130+
split=_pl_stateful_split,
131+
random_bits=_pl_stateful_random_bits,
132+
fold_in=_pl_stateful_fold_in,
133+
name="_pallas_internal_stateful",
134+
tag="_pl_stateful"
135+
)
136+
jax_prng.register_prng(tpu_internal_stateful_impl)
137+
138+
def set_seed(seed: Union[jnp.int32, jax.Array]):
139+
"""Sets the seed for PRNG.
140+
141+
Args:
142+
seeds: An integer seed for setting the PRNG seed.
143+
"""
144+
if isinstance(seed, jax.Array):
145+
if seed.ndim != 1:
146+
raise ValueError("Seed data must be a scalar or 1D array")
147+
# TODO(justinfu): Mosaic currently only supports indexing by 0
148+
# for scalar results when using vector.extract
149+
# After support is added, use all seed data.
150+
prng_seed(seed[0])
151+
else:
152+
prng_seed(seed)
153+
154+
155+
SampleFnType = Any
156+
KeylessSampleFnType = Callable[..., jax.Array]
157+
158+
def _make_stateful_sampler(sampler: SampleFnType) -> KeylessSampleFnType:
159+
"""Converts a jax.random sampling function to a stateful version.
160+
161+
Args:
162+
sampler: A sampling function that consumes a key and returns
163+
random samples.
164+
165+
Returns:
166+
A stateful sampling function with the key argument removed.
167+
"""
168+
def new_sampler(*args, **kwargs):
169+
# Pass in a placeholder key into the sampling function.
170+
# The key is ignored by the stateful random_bits function, but all jax
171+
# sampling functions expect a key as input so we must pass one in here.
172+
placeholder_key = jax_prng.random_seed(
173+
None, impl=tpu_internal_stateful_impl)
174+
return sampler(placeholder_key, *args, **kwargs)
175+
# Remove key argument from docstring.
176+
doc_lines = filter(
177+
lambda line: 'key:' not in line, sampler.__doc__.split('\n'))
178+
new_sampler.__doc__ = '\n'.join(doc_lines)
179+
return new_sampler
180+
181+
bits = _make_stateful_sampler(jax_api_random.bits)
182+
uniform = _make_stateful_sampler(jax_api_random.uniform)
183+
bernoulli = _make_stateful_sampler(jax_api_random.bernoulli)

tests/pallas/tpu/BUILD

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
load(
16+
"//jaxlib:jax.bzl",
17+
"jax_generate_backend_suites",
18+
"jax_test",
19+
"py_deps",
20+
)
21+
22+
licenses(["notice"])
23+
24+
package(
25+
default_applicable_licenses = [],
26+
default_visibility = ["//visibility:private"],
27+
)
28+
29+
jax_generate_backend_suites()
30+
31+
jax_test(
32+
name = "pallas_random_test",
33+
srcs = [
34+
"pallas_random_test.py",
35+
],
36+
disable_backends = [
37+
"cpu",
38+
"gpu",
39+
],
40+
deps = [
41+
"//jax:pallas",
42+
"//jax:pallas_tpu",
43+
"//jax/_src/pallas/mosaic:random",
44+
"//third_party/py/absl/testing:absltest",
45+
"//third_party/py/absl/testing:parameterized",
46+
] + py_deps("numpy"),
47+
)

0 commit comments

Comments
 (0)