@@ -1024,7 +1024,7 @@ def make_array_from_single_device_arrays(
10241024 shape : Shape of the output ``jax.Array``. This conveys information already included with
10251025 ``sharding`` and ``arrays`` and serves as a double check.
10261026 sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices.
1027- arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)``
1027+ arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)``
10281028 must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code,
10291029 each process will call with a different ``arrays`` argument that corresponds to that processes' data.
10301030 These arrays are commonly created via ``jax.device_put``.
@@ -1071,14 +1071,15 @@ def make_array_from_single_device_arrays(
10711071 if dtypes .issubdtype (aval .dtype , dtypes .extended ):
10721072 return aval .dtype ._rules .make_sharded_array (aval , sharding , arrays ,
10731073 committed = True )
1074+ arrays = list (arrays ) if isinstance (arrays , tuple ) else arrays
10741075 # TODO(phawkins): ideally the cast() could be checked.
10751076 try :
10761077 return ArrayImpl (aval , sharding , cast (Sequence [ArrayImpl ], arrays ),
10771078 committed = True )
10781079 except TypeError :
1079- if not isinstance (arrays , Sequence ):
1080+ if not isinstance (arrays , list ):
10801081 raise TypeError ("jax.make_array_from_single_device_arrays `arrays` "
1081- "argument must be a Sequence ( list or tuple) , but got "
1082+ "argument must be a list or tuple, but got "
10821083 f"{ type (arrays )} ." )
10831084 if any (isinstance (arr , core .Tracer ) for arr in arrays ):
10841085 raise ValueError (
0 commit comments