Skip to content

Commit 156e076

Browse files
Merge pull request #33492 from gnecula:export_memories
PiperOrigin-RevId: 836271707
2 parents b342917 + 5a01e41 commit 156e076

File tree

5 files changed

+91
-41
lines changed

5 files changed

+91
-41
lines changed

jax/_src/export/serialization.fbs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,18 @@ enum DType: byte {
8282
key_unsafe_rbg = 29,
8383
}
8484

85+
enum MemorySpace: byte {
86+
Missing = 0, // default if missing (pre 11/25/2025)
87+
Device = 1,
88+
Host = 2,
89+
Any = 3,
90+
}
91+
8592
table AbstractValue {
8693
kind: AbstractValueKind;
8794
shape: [string]; // Support shape polymorphism
8895
dtype: DType;
96+
memory_space: MemorySpace;
8997
}
9098

9199
enum ShardingKind: byte {
@@ -119,14 +127,15 @@ table Exported {
119127
/// Note that this field has different semantics and purpose from
120128
/// `mlir_module_serialization_version`, which encodes
121129
/// the calling convention of the `mlir_module_serialized`.
130+
/// See comments in serialization.py for more details.
122131
serialization_version: uint16;
123132

124133
function_name: string;
125134
in_tree: PyTreeDef;
126135
in_avals: [AbstractValue];
127136
out_tree: PyTreeDef;
128137
out_avals: [AbstractValue];
129-
nr_devices: short;
138+
nr_devices_short: short; // Deprecated as of 11/25/2025
130139
in_shardings: [Sharding];
131140
out_shardings: [Sharding];
132141

@@ -142,6 +151,7 @@ table Exported {
142151
uses_global_constants: bool;
143152

144153
vjp: Exported;
154+
nr_devices: uint32 = 0; // Added 11/25/2025
145155
}
146156

147157
root_type Exported;

jax/_src/export/serialization.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,15 @@
5050
# This version is backwards compatible with Version 2.
5151
# Version 4, April 7th, 2025, adds serialization for PRNGs key types.
5252
# This version is backwards compatible with Version 2 and 3.
53-
_SERIALIZATION_VERSION = 2
53+
# Version 5, November 23rd, 2025, adds serialization for aval memory_space,
54+
# upgrade num_devices to a 32 bit value.
55+
# This version is backwards compatible with Version 2 to 4.
56+
# TODO(necula): we cannot really store the actual serialization_version
57+
# in the flatbuffer because prior to 11/25/2025 deserializers checked
58+
# if the version is 2 or 3. I have now removed that check, but for the
59+
# sake of old deserializers we can only store version 3. Starting
60+
# on January 2026 we can store the actual version.
61+
_SERIALIZATION_VERSION = 3
5462

5563
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
5664
"""Serializes an Exported.
@@ -157,10 +165,6 @@ def _serialize_array(
157165

158166
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
159167
serialization_version = exp.SerializationVersion()
160-
if serialization_version not in [2, 3]:
161-
raise NotImplementedError(
162-
f"deserialize unsupported version {serialization_version}"
163-
)
164168

165169
fun_name = exp.FunctionName().decode("utf-8")
166170
in_tree = tree_util.tree_structure(
@@ -177,7 +181,10 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
177181
out_avals = _deserialize_tuple(
178182
exp.OutAvalsLength, exp.OutAvals, deser_aval
179183
)
180-
nr_devices = exp.NrDevices()
184+
# TODO(necula): remove the fallback to NrDevicesShort and mark
185+
# the field "deprecated" once we abandon the old
186+
# serialization format (6 months after 11/24/2025).
187+
nr_devices = exp.NrDevices() or exp.NrDevicesShort()
181188
in_shardings = _deserialize_tuple(
182189
exp.InShardingsLength, exp.InShardings, _deserialize_sharding
183190
)
@@ -381,6 +388,14 @@ def register_dtype_kind(dtype: Any, kind: int):
381388
_dtype_kind_to_dtype[kind] = dtype
382389

383390

391+
_memory_space_to_enum = {
392+
core.MemorySpace.Device: ser_flatbuf.MemorySpace.Device,
393+
core.MemorySpace.Host: ser_flatbuf.MemorySpace.Host,
394+
core.MemorySpace.Any: ser_flatbuf.MemorySpace.Any,
395+
}
396+
_memory_space_from_enum = {v: k for k, v in _memory_space_to_enum.items()}
397+
398+
384399
def _serialize_aval(
385400
builder: flatbuffers.Builder, aval: core.ShapedArray
386401
) -> int:
@@ -395,6 +410,7 @@ def _serialize_aval(
395410
ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
396411
ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
397412
ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
413+
ser_flatbuf.AbstractValueAddMemorySpace(builder, _memory_space_to_enum[aval.memory_space])
398414
return ser_flatbuf.AbstractValueEnd(builder)
399415

400416

@@ -409,7 +425,8 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue,
409425
),
410426
scope=scope
411427
)
412-
return core.ShapedArray(shape, dtype)
428+
mem_space = aval.MemorySpace() or ser_flatbuf.MemorySpace.Device
429+
return core.ShapedArray(shape, dtype, memory_space=_memory_space_from_enum[mem_space])
413430
else:
414431
assert False, aval_kind
415432

jax/_src/export/serialization_generated.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from flatbuffers.compat import import_numpy
2222
np = import_numpy()
2323

24-
class PyTreeDefKind:
24+
class PyTreeDefKind(object):
2525
leaf = 0
2626
none = 1
2727
tuple = 2
@@ -30,12 +30,12 @@ class PyTreeDefKind:
3030
custom = 5
3131

3232

33-
class AbstractValueKind:
33+
class AbstractValueKind(object):
3434
shapedArray = 0
3535
abstractToken = 1
3636

3737

38-
class DType:
38+
class DType(object):
3939
bool = 0
4040
i8 = 1
4141
i16 = 2
@@ -68,18 +68,25 @@ class DType:
6868
key_unsafe_rbg = 29
6969

7070

71-
class ShardingKind:
71+
class MemorySpace(object):
72+
Missing = 0
73+
Device = 1
74+
Host = 2
75+
Any = 3
76+
77+
78+
class ShardingKind(object):
7279
unspecified = 0
7380
hlo_sharding = 1
7481

7582

76-
class DisabledSafetyCheckKind:
83+
class DisabledSafetyCheckKind(object):
7784
platform = 0
7885
custom_call = 1
7986
shape_assertions = 2
8087

8188

82-
class PyTreeDef:
89+
class PyTreeDef(object):
8390
__slots__ = ['_tab']
8491

8592
@classmethod
@@ -214,7 +221,7 @@ def PyTreeDefEnd(builder):
214221

215222

216223

217-
class AbstractValue:
224+
class AbstractValue(object):
218225
__slots__ = ['_tab']
219226

220227
@classmethod
@@ -266,8 +273,15 @@ def Dtype(self):
266273
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
267274
return 0
268275

276+
# AbstractValue
277+
def MemorySpace(self):
278+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
279+
if o != 0:
280+
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
281+
return 0
282+
269283
def AbstractValueStart(builder):
270-
builder.StartObject(3)
284+
builder.StartObject(4)
271285

272286
def AbstractValueAddKind(builder, kind):
273287
builder.PrependInt8Slot(0, kind, 0)
@@ -281,12 +295,15 @@ def AbstractValueStartShapeVector(builder, numElems):
281295
def AbstractValueAddDtype(builder, dtype):
282296
builder.PrependInt8Slot(2, dtype, 0)
283297

298+
def AbstractValueAddMemorySpace(builder, memorySpace):
299+
builder.PrependInt8Slot(3, memorySpace, 0)
300+
284301
def AbstractValueEnd(builder):
285302
return builder.EndObject()
286303

287304

288305

289-
class Sharding:
306+
class Sharding(object):
290307
__slots__ = ['_tab']
291308

292309
@classmethod
@@ -355,7 +372,7 @@ def ShardingEnd(builder):
355372

356373

357374

358-
class Effect:
375+
class Effect(object):
359376
__slots__ = ['_tab']
360377

361378
@classmethod
@@ -391,7 +408,7 @@ def EffectEnd(builder):
391408

392409

393410

394-
class DisabledSafetyCheck:
411+
class DisabledSafetyCheck(object):
395412
__slots__ = ['_tab']
396413

397414
@classmethod
@@ -437,7 +454,7 @@ def DisabledSafetyCheckEnd(builder):
437454

438455

439456

440-
class Exported:
457+
class Exported(object):
441458
__slots__ = ['_tab']
442459

443460
@classmethod
@@ -460,6 +477,7 @@ def Init(self, buf, pos):
460477
# Note that this field has different semantics and purpose from
461478
# `mlir_module_serialization_version`, which encodes
462479
# the calling convention of the `mlir_module_serialized`.
480+
# See comments in serialization.py for more details.
463481
# Exported
464482
def SerializationVersion(self):
465483
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
@@ -543,7 +561,7 @@ def OutAvalsIsNone(self):
543561
return o == 0
544562

545563
# Exported
546-
def NrDevices(self):
564+
def NrDevicesShort(self):
547565
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
548566
if o != 0:
549567
return self._tab.Get(flatbuffers.number_types.Int16Flags, o + self._tab.Pos)
@@ -767,8 +785,15 @@ def Vjp(self):
767785
return obj
768786
return None
769787

788+
# Exported
789+
def NrDevices(self):
790+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(40))
791+
if o != 0:
792+
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
793+
return 0
794+
770795
def ExportedStart(builder):
771-
builder.StartObject(18)
796+
builder.StartObject(19)
772797

773798
def ExportedAddSerializationVersion(builder, serializationVersion):
774799
builder.PrependUint16Slot(0, serializationVersion, 0)
@@ -794,8 +819,8 @@ def ExportedAddOutAvals(builder, outAvals):
794819
def ExportedStartOutAvalsVector(builder, numElems):
795820
return builder.StartVector(4, numElems, 4)
796821

797-
def ExportedAddNrDevices(builder, nrDevices):
798-
builder.PrependInt16Slot(6, nrDevices, 0)
822+
def ExportedAddNrDevicesShort(builder, nrDevicesShort):
823+
builder.PrependInt16Slot(6, nrDevicesShort, 0)
799824

800825
def ExportedAddInShardings(builder, inShardings):
801826
builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(inShardings), 0)
@@ -854,5 +879,8 @@ def ExportedAddUsesGlobalConstants(builder, usesGlobalConstants):
854879
def ExportedAddVjp(builder, vjp):
855880
builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(vjp), 0)
856881

882+
def ExportedAddNrDevices(builder, nrDevices):
883+
builder.PrependUint32Slot(18, nrDevices, 0)
884+
857885
def ExportedEnd(builder):
858886
return builder.EndObject()

tests/export_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,18 @@ def f_with_sharding(x):
14051405
f"context with {jax.local_device_count()} devices"):
14061406
exp.call(b)
14071407

1408+
def test_memory_space(self):
1409+
shd = jax.sharding.SingleDeviceSharding(
1410+
jax.devices()[0], memory_kind="pinned_host")
1411+
a = jax.device_put(1, shd)
1412+
f = jax.jit(lambda x: x)
1413+
1414+
exported = get_exported(f, platforms=("tpu", "cuda"))(a)
1415+
self.assertEqual(exported.in_avals[0].memory_space, core.MemorySpace.Host)
1416+
if jtu.device_under_test in ("tpu", "gpu"):
1417+
b = exported.call(a)
1418+
self.assertEqual(b.sharding, a.sharding)
1419+
14081420
@jtu.parameterized_filterable(
14091421
kwargs=[
14101422
dict(testcase_name=f"_poly={poly}", poly=poly)

tests/memories_test.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -743,23 +743,6 @@ def test_host_to_device_transfer(self):
743743
self.assertEqual(d.sharding.memory_kind, 'device')
744744
self.assertArraysEqual(d, orig)
745745

746-
def test_memory_space_propagated_identity_jit(self):
747-
shd = jax.sharding.SingleDeviceSharding(
748-
jax.devices()[0], memory_kind='pinned_host')
749-
a = jax.device_put(1, shd)
750-
751-
f = jax.jit(lambda x: x, out_shardings=shd)
752-
b = f(a)
753-
self.assertEqual(b.sharding, a.sharding)
754-
755-
f = jax.jit(lambda x: x)
756-
b = f(a)
757-
self.assertEqual(b.sharding, a.sharding)
758-
759-
exported = jax.export.export(f)(a)
760-
b = exported.call(a)
761-
self.assertEqual(b.sharding, a.sharding)
762-
763746

764747
class ComputeOffload(jtu.BufferDonationTestCase):
765748

0 commit comments

Comments
 (0)