Skip to content

Commit c2eaedf

Browse files
Merge pull request jax-ml#27776 from gnecula:export_keys
PiperOrigin-RevId: 745038060
2 parents 19fcae1 + ce7dc85 commit c2eaedf

File tree

5 files changed

+30
-4
lines changed

5 files changed

+30
-4
lines changed

jax/_src/export/serialization.fbs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ enum AbstractValueKind: byte {
4545
}
4646

4747
enum DType: byte {
48-
// Last used id: 22
48+
// Last used id: 29
4949
bool = 0,
5050
i8 = 1,
5151
i16 = 2,
@@ -76,6 +76,10 @@ enum DType: byte {
7676
f8_e5m2fnuz = 21,
7777
f8_e8m0fnu = 25,
7878
f4_e2m1fn = 26,
79+
80+
key_fry = 27,
81+
key_rbg = 28,
82+
key_unsafe_rbg = 29,
7983
}
8084

8185
table AbstractValue {

jax/_src/export/serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src import core
3232
from jax._src import dtypes
3333
from jax._src import effects
34+
from jax._src import prng
3435
from jax._src import tree_util
3536
from jax._src.export import serialization_generated as ser_flatbuf
3637
from jax._src.export import _export
@@ -48,6 +49,8 @@
4849
# Version 2, Dec 16th, 2023, adds the f0 dtype.
4950
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
5051
# This version is backwards compatible with Version 2.
52+
# Version 4, April 7th, 2025, adds serialization for PRNGs key types.
53+
# This version is backwards compatible with Version 2 and 3.
5154
_SERIALIZATION_VERSION = 2
5255

5356
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
@@ -361,6 +364,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
361364
dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
362365
dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
363366
dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
367+
368+
prng.KeyTy(prng.prngs["threefry2x32"]): ser_flatbuf.DType.key_fry,
369+
prng.KeyTy(prng.prngs["rbg"]): ser_flatbuf.DType.key_rbg,
370+
prng.KeyTy(prng.prngs["unsafe_rbg"]): ser_flatbuf.DType.key_unsafe_rbg,
364371
}
365372

366373
_dtype_kind_to_dtype = {

jax/_src/export/serialization_generated.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,19 @@ class DType(object):
5353
bf16 = 14
5454
i4 = 15
5555
ui4 = 16
56-
f8_e3m4 = 24
57-
f8_e4m3 = 23
5856
f8_e4m3b11fnuz = 17
5957
f8_e4m3fn = 18
6058
f8_e4m3fnuz = 19
6159
f8_e5m2 = 20
6260
f8_e5m2fnuz = 21
6361
f0 = 22
62+
f8_e4m3 = 23
63+
f8_e3m4 = 24
6464
f8_e8m0fnu = 25
6565
f4_e2m1fn = 26
66+
key_fry = 27
67+
key_rbg = 28
68+
key_unsafe_rbg = 29
6669

6770

6871
class ShardingKind(object):

jax/_src/prng.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def pprint(self):
113113
]))))
114114

115115

116-
prngs = {}
116+
prngs: dict[str, PRNGImpl] = {}
117117

118118
def register_prng(impl: PRNGImpl):
119119
if impl.name in prngs:

tests/export_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,18 @@ def f(x1, x2):
421421
self.assertEqual(tree_util.tree_structure(res2),
422422
tree_util.tree_structure(res))
423423

424+
@jtu.parameterized_filterable(
425+
kwargs=[dict(impl=p)
426+
for p in ("rbg", "unsafe_rbg", "threefry2x32")])
427+
def test_prng_keys(self, *, impl):
428+
429+
key = jax.random.key(42, impl=impl)
430+
@jax.jit
431+
def f(key):
432+
return key
433+
exp_f = get_exported(jax.jit(f))(key)
434+
self.assertEqual(f(key), exp_f.call(key))
435+
424436
def test_error_wrong_intree(self):
425437
def f(a_b_pair, *, c):
426438
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c

0 commit comments

Comments
 (0)