Skip to content

Commit 72e5ca9

Browse files
maxwillzqGoogle-ML-Automation
authored andcommitted
[JAX] Fix a small bug if shardings is tuple.
# Details `jax.tree.map` requests all its arguments to have the same data type. From ```[None] * len(tensorstore_specs) if global_shapes is None else global_shapes```, The data type is already decided to be a list. So if we pass `sharding` or `tspecs` as a tuple, it will fail. Here we add an explicit conversion to a list for sharding and tspecs. PiperOrigin-RevId: 707576866
1 parent 464e5a2 commit 72e5ca9

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

jax/experimental/array_serialization/serialization.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,9 @@ def run_deserialization(shardings: Sequence[sharding.Sharding | Layout],
420420
async def _run_deserializer():
421421
# Object should be created once per process.
422422
byte_limiter = _LimitInFlightBytes(concurrent_bytes)
423-
424423
future_arrays = jax.tree_util.tree_map(
425424
partial(async_deserialize, byte_limiter=byte_limiter),
426-
shardings, tensorstore_specs,
425+
list(shardings), list(tensorstore_specs),
427426
[None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
428427
[None] * len(tensorstore_specs) if dtypes is None else dtypes)
429428
return await asyncio.gather(*future_arrays)

jax/experimental/array_serialization/serialization_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,40 @@ class CheckpointTest(jtu.JaxTestCase):
5050
def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
5151
os.rename(temp_ckpt_dir, final_ckpt_dir)
5252

53+
def test_deserialize_on_array_list(self):
54+
global_mesh = jtu.create_mesh((2, 4), ('x', 'y'))
55+
inp_shape = (16, 64)
56+
pspec = P('x', 'y')
57+
sharding = NamedSharding(global_mesh, pspec)
58+
inputs = []
59+
lambda_fn = lambda idx: src[idx]
60+
num_arrays = 5
61+
for _ in range(num_arrays):
62+
src = jax.random.normal(jax.random.key(0), inp_shape)
63+
inp = array.make_array_from_callback(inp_shape, sharding, lambda_fn)
64+
inputs.append(inp)
65+
ckpt_dir = pathlib.Path(self.create_tempdir().full_path)
66+
tspecs = [
67+
serialization.get_tensorstore_spec(f'{ckpt_dir}/array_{i}')
68+
for i in range(num_arrays)
69+
]
70+
inputs = tuple(inputs)
71+
tspecs = tuple(tspecs)
72+
manager = serialization.GlobalAsyncCheckpointManager()
73+
manager.serialize(
74+
inputs,
75+
tspecs,
76+
on_commit_callback=partial(
77+
self._on_commit_callback, ckpt_dir, ckpt_dir
78+
),
79+
)
80+
manager.wait_until_finished()
81+
shardings = tuple([sharding] * num_arrays)
82+
restored_arrays = manager.deserialize(shardings, tspecs)
83+
self.assertLen(restored_arrays, num_arrays)
84+
for inp, deserialized_array in zip(inputs, restored_arrays):
85+
self.assertArraysEqual(deserialized_array, inp)
86+
5387
@jtu.skip_on_devices('cpu')
5488
def test_memory_consumption(self):
5589
global_mesh = jtu.create_mesh((2, 4), ('x', 'y'))

0 commit comments

Comments
 (0)