@@ -673,6 +673,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
   return concatenate_p.bind(*operands, dimension=dimension)
 
 
+def split(operand: ArrayLike, sizes: Sequence[int],
+          axis: int = 0) -> Sequence[Array]:
+  """Splits an array along ``axis``.
+
+  Args:
+    operand: an array to split
+    sizes: the sizes of the split arrays. The sum of the sizes must be equal
+      to the size of the ``axis`` dimension of ``operand``.
+    axis: the axis along which to split the array.
+
+  Returns:
+    A sequence of ``len(sizes)`` arrays. If ``sizes`` is
+    ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
+    etc., taken along ``axis``.
+  """
+  operand = asarray(operand)
+  return split_p.bind(operand, sizes=tuple(sizes),
+                      axis=canonicalize_axis(axis, operand.ndim))
+
+
 _precision_strings: dict[Any, Precision] = {}
 
 class Precision(enum.Enum):
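
For context, a minimal usage sketch of the API introduced above, assuming it is exposed as `jax.lax.split` (the example values are illustrative, not taken from the commit):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(12.0).reshape(3, 4)
    # Split axis 1 (of size 4) into chunks of sizes 1 and 3.
    a, b = lax.split(x, sizes=(1, 3), axis=1)
    assert a.shape == (3, 1) and b.shape == (3, 3)

The sizes must sum to the length of the split axis, which matches the check in `_split_shape_rule` later in this diff.
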
@@ -4454,18 +4474,8 @@ def _concatenate_transpose_rule(t, *operands, dimension):
     return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
             for o in operands]
   else:
-    limit_points = np.cumsum(
-        [shape[dimension] for shape in operand_shapes]).tolist()
-    starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
-    limits = np.tile(t.shape, (len(operands), 1)).tolist()
-
-    for i, s in enumerate(starts[1:]):
-      s[dimension] = limit_points[:-1][i]
-    for i, l in enumerate(limits):
-      l[dimension] = limit_points[i]
-
-    return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
-            else None for o, start, limit in zip(operands, starts, limits)]
+    return split(t, tuple(shape[dimension] for shape in operand_shapes),
+                 axis=dimension)
 
 def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
   size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
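
The rewritten transpose rule above relies on the fact that the VJP of concatenation simply splits the incoming cotangent back into pieces matching the operand shapes. A small sketch of that behavior through the public API (illustrative only):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 3))
    y = jnp.ones((2, 5))
    _, vjp = jax.vjp(lambda a, b: jnp.concatenate([a, b], axis=1), x, y)
    ct = jnp.arange(16.0).reshape(2, 8)
    dx, dy = vjp(ct)
    # dx == ct[:, :3] and dy == ct[:, 3:], i.e. the cotangent split
    # along the concatenation axis into the original operand shapes.
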
@@ -4499,6 +4509,76 @@ def _concatenate_lower(ctx, *xs, dimension):
 mlir.register_lowering(concatenate_p, _concatenate_lower)
 
 
+def _split_shape_rule(operand, *, sizes, axis):
+  shapes = []
+  shape = list(operand.shape)
+  if any(s < 0 for s in sizes):
+    raise ValueError(
+        f"Sizes passed to split must be nonnegative, got {list(sizes)}")
+  if operand.shape[axis] != np.sum(sizes):
+    raise ValueError(
+        f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
+        f"operand shape {list(operand.shape)}")
+  for size in sizes:
+    shape[axis] = size
+    shapes.append(tuple(shape))
+  return shapes
+
+def _split_dtype_rule(operand, *, sizes, axis):
+  return (operand.dtype,) * len(sizes)
+
+def _split_weak_type_rule(operand, *, sizes, axis):
+  return (operand.weak_type,) * len(sizes)
+
+def _split_transpose_rule(cotangents, operand, *, sizes, axis):
+  assert ad.is_undefined_primal(operand)
+  if all(type(t) is ad_util.Zero for t in cotangents):
+    return ad_util.Zero(operand.aval),
+  cotangents = [
+      _zeros(t.aval) if type(t) is ad_util.Zero else t
+      for t in cotangents
+  ]
+  return concatenate(cotangents, dimension=axis),
+
+def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
+  operand, = batched_args
+  bdim, = batch_dims
+  new_bdims = (bdim,) * len(sizes)
+  out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
+  return out, new_bdims
+
+def _split_lower(ctx, x, *, sizes, axis):
+  x_aval, = ctx.avals_in
+  start_indices = [0] * x_aval.ndim
+  limit_indices = list(x_aval.shape)
+  strides = (1,) * x_aval.ndim
+  outs = []
+  for aval_out in ctx.avals_out:
+    limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
+    out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
+                        limit_indices=limit_indices, strides=strides)
+    outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
+                if config.sharding_in_types.value else out)
+    start_indices[axis] = limit_indices[axis]
+  return outs
+
+def _split_sharding_rule(operand, *, sizes, axis):
+  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+  # change this logic to `return operand.sharding` directly.
+  out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis)
+  return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split')
+          for out_sh in out_shapes]
+
+split_p = core.Primitive('split')
+split_p.multiple_results = True
+split_p.def_abstract_eval(
+    partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
+            _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule))
+split_p.def_impl(partial(dispatch.apply_primitive, split_p))
+ad.deflinear2(split_p, _split_transpose_rule)
+batching.primitive_batchers[split_p] = _split_batch_rule
+mlir.register_lowering(split_p, _split_lower)
+
 def _pad_dtype_rule(operand, padding_value, *, padding_config):
   if operand.dtype != padding_value.dtype:
     msg = "pad operand and padding_value must be same dtype: got {} and {}."
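
As a quick sanity sketch of the batching rule registered above (hypothetical example, not part of the commit), `split` composes with `vmap` by shifting the split axis past the batch dimension:

    import jax
    import jax.numpy as jnp
    from jax import lax

    batch = jnp.ones((4, 6))                      # leading batch dimension
    f = lambda row: lax.split(row, sizes=(2, 4))  # splits a length-6 vector
    a, b = jax.vmap(f)(batch)
    assert a.shape == (4, 2) and b.shape == (4, 4)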