Skip to content

Commit ea7fa29

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Allow tuple(arrays) as an input to make_array_from_single_device_arrays. Fixes jax-ml#27303
PiperOrigin-RevId: 738917340
1 parent 59e480d commit ea7fa29

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

jax/_src/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ def make_array_from_single_device_arrays(
10241024
shape : Shape of the output ``jax.Array``. This conveys information already included with
10251025
``sharding`` and ``arrays`` and serves as a double check.
10261026
sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices.
1027-
arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)``
1027+
arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)``
10281028
must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code,
10291029
each process will call with a different ``arrays`` argument that corresponds to that processes' data.
10301030
These arrays are commonly created via ``jax.device_put``.
@@ -1071,14 +1071,15 @@ def make_array_from_single_device_arrays(
10711071
if dtypes.issubdtype(aval.dtype, dtypes.extended):
10721072
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
10731073
committed=True)
1074+
arrays = list(arrays) if isinstance(arrays, tuple) else arrays
10741075
# TODO(phawkins): ideally the cast() could be checked.
10751076
try:
10761077
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
10771078
committed=True)
10781079
except TypeError:
1079-
if not isinstance(arrays, Sequence):
1080+
if not isinstance(arrays, list):
10801081
raise TypeError("jax.make_array_from_single_device_arrays `arrays` "
1081-
"argument must be a Sequence (list or tuple), but got "
1082+
"argument must be a list or tuple, but got "
10821083
f"{type(arrays)}.")
10831084
if any(isinstance(arr, core.Tracer) for arr in arrays):
10841085
raise ValueError(

tests/array_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,18 @@ def f(x):
13011301
with self.assertRaisesRegex(TypeError, msg):
13021302
jax.jit(f)(x)
13031303

1304+
def test_make_array_from_single_device_arrays_tuple(self):
1305+
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
1306+
shape = (8, 8)
1307+
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
1308+
inp_data = np.arange(math.prod(shape)).reshape(shape)
1309+
1310+
arrays = tuple(
1311+
jax.device_put(inp_data[index], d)
1312+
for d, index in s.addressable_devices_indices_map(shape).items())
1313+
1314+
jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash
1315+
13041316
def test_make_array_from_single_device_arrays_bad_inputs(self):
13051317
x = jnp.arange(10)
13061318
mesh = jtu.create_mesh((2,), ('x',))

0 commit comments

Comments
 (0)