Skip to content

Commit ae46b75

Browse files
Merge pull request #24593 from froystig:random-dtypes
PiperOrigin-RevId: 698268678
2 parents 4d60db1 + 4bb8107 commit ae46b75

File tree

3 files changed

+25
-18
lines changed

3 files changed

+25
-18
lines changed

jax/_src/random.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,15 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
293293
return _return_prng_keys(wrapped, _split(typed_key, num))
294294

295295

296-
def _key_impl(keys: KeyArray) -> PRNGImpl:
296+
def _key_impl(keys: KeyArray) -> str | PRNGSpec:
297297
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
298298
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
299-
return keys_dtype._impl
299+
impl = keys_dtype._impl
300+
return impl.name if impl.name in prng.prngs else PRNGSpec(impl)
300301

301-
def key_impl(keys: KeyArrayLike) -> PRNGSpec:
302+
def key_impl(keys: KeyArrayLike) -> str | PRNGSpec:
302303
typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
303-
return PRNGSpec(_key_impl(typed_keys))
304+
return _key_impl(typed_keys)
304305

305306

306307
def _key_data(keys: KeyArray) -> Array:

tests/extend_test.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,35 +73,41 @@ def test_symbols(self):
7373

7474
class RandomTest(jtu.JaxTestCase):
7575

76-
def test_key_make_with_custom_impl(self):
77-
shape = (4, 2, 7)
78-
76+
def make_custom_impl(self, shape, seed=False, split=False, fold_in=False,
77+
random_bits=False):
78+
assert not split and not fold_in and not random_bits # not yet implemented
7979
def seed_rule(_):
8080
return jnp.ones(shape, dtype=jnp.dtype('uint32'))
8181

8282
def no_rule(*args, **kwargs):
8383
assert False, 'unreachable'
8484

85-
impl = jex.random.define_prng_impl(
86-
key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule,
87-
random_bits=no_rule)
85+
return jex.random.define_prng_impl(
86+
key_shape=shape, seed=seed_rule if seed else no_rule, split=no_rule,
87+
fold_in=no_rule, random_bits=no_rule)
88+
89+
def test_key_make_with_custom_impl(self):
90+
impl = self.make_custom_impl(shape=(4, 2, 7), seed=True)
8891
k = jax.random.key(42, impl=impl)
8992
self.assertEqual(k.shape, ())
9093
self.assertEqual(impl, jax.random.key_impl(k))
9194

9295
def test_key_wrap_with_custom_impl(self):
93-
def no_rule(*args, **kwargs):
94-
assert False, 'unreachable'
95-
9696
shape = (4, 2, 7)
97-
impl = jex.random.define_prng_impl(
98-
key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule,
99-
random_bits=no_rule)
97+
impl = self.make_custom_impl(shape=shape)
10098
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
10199
k = jax.random.wrap_key_data(data, impl=impl)
102100
self.assertEqual(k.shape, (3,))
103101
self.assertEqual(impl, jax.random.key_impl(k))
104102

103+
def test_key_impl_is_spec(self):
104+
# this is counterpart to random_test.py:
105+
# KeyArrayTest.test_key_impl_builtin_is_string_name
106+
spec_ref = self.make_custom_impl(shape=(4, 2, 7), seed=True)
107+
key = jax.random.key(42, impl=spec_ref)
108+
spec = jax.random.key_impl(key)
109+
self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})")
110+
105111

106112
class FfiTest(jtu.JaxTestCase):
107113

tests/random_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,10 +1125,10 @@ class A: pass
11251125
jax.random.key(42, impl=A())
11261126

11271127
@jtu.sample_product(name=[name for name, _ in PRNG_IMPLS])
1128-
def test_key_spec_repr(self, name):
1128+
def test_key_impl_builtin_is_string_name(self, name):
11291129
key = jax.random.key(42, impl=name)
11301130
spec = jax.random.key_impl(key)
1131-
self.assertEqual(repr(spec), f"PRNGSpec({name!r})")
1131+
self.assertEqual(spec, name)
11321132

11331133
def test_keyarray_custom_vjp(self):
11341134
# Regression test for https://github.com/jax-ml/jax/issues/18442

0 commit comments

Comments
 (0)