Commit 7de9eb2

hawkinsp authored and Google-ML-Automation committed
Reverts 525b646
PiperOrigin-RevId: 707146329
1 parent 4911a39 commit 7de9eb2

File tree

13 files changed: +237 -31 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 * {func}`jax.export.export` can be used for device-polymorphic export with
   shardings constructed with {func}`jax.sharding.AbstractMesh`.
   See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
+* Added {func}`jax.lax.split`. This is a primitive version of
+  {func}`jax.numpy.split`, added because it yields a more compact
+  transpose during automatic differentiation.
 
 ## jax 0.4.37 (Dec 9, 2024)
 
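
For readers skimming the change, a minimal usage sketch of the new function (the array and sizes here are illustrative, not taken from the diff):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(3, 4)
# Split the second axis into chunks of sizes 1 and 3.
a, b = lax.split(x, sizes=(1, 3), axis=1)
print(a.shape, b.shape)   # expected: (3, 1) (3, 3)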

docs/jax.lax.rst

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ Operators
     slice_in_dim
     sort
     sort_key_val
+    split
     sqrt
     square
     squeeze

jax/_src/lax/lax.py

Lines changed: 92 additions & 12 deletions
@@ -673,6 +673,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
   return concatenate_p.bind(*operands, dimension=dimension)
 
 
+def split(operand: ArrayLike, sizes: Sequence[int],
+          axis: int = 0) -> Sequence[Array]:
+  """Splits an array along ``axis``.
+
+  Args:
+    operand: an array to split
+    sizes: the sizes of the split arrays. The sum of the sizes must be equal
+      to the size of the ``axis`` dimension of ``operand``.
+    axis: the axis along which to split the array.
+
+  Returns:
+    A sequence of ``len(sizes)`` arrays. If ``sizes`` is
+    ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``,
+    taken along ``axis``.
+  """
+  operand = asarray(operand)
+  return split_p.bind(operand, sizes=tuple(sizes),
+                      axis=canonicalize_axis(axis, operand.ndim))
+
+
 _precision_strings: dict[Any, Precision] = {}
 
 class Precision(enum.Enum):
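
A brief sketch of the contract documented in the docstring above, with illustrative shapes: the sizes must sum to the length of the chosen axis.

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((5, 3))
parts = lax.split(x, (1, 4), axis=0)
print([p.shape for p in parts])   # expected: [(1, 3), (4, 3)]
# Sizes that are negative, or that do not sum to x.shape[axis],
# are rejected by the shape rule below with a ValueError.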
@@ -4454,18 +4474,8 @@ def _concatenate_transpose_rule(t, *operands, dimension):
     return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None
             for o in operands]
   else:
-    limit_points = np.cumsum(
-        [shape[dimension] for shape in operand_shapes]).tolist()
-    starts = np.zeros((len(operands), t.ndim), dtype=int).tolist()
-    limits = np.tile(t.shape, (len(operands), 1)).tolist()
-
-    for i, s in enumerate(starts[1:]):
-      s[dimension] = limit_points[:-1][i]
-    for i, l in enumerate(limits):
-      l[dimension] = limit_points[i]
-
-    return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o)
-            else None for o, start, limit in zip(operands, starts, limits)]
+    return split(t, tuple(shape[dimension] for shape in operand_shapes),
+                 axis=dimension)
 
 def _concatenate_batch_rule(batched_args, batch_dims, *, dimension):
   size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims)
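
A rough way to see the "more compact transpose" claim from the changelog, with hypothetical shapes: the backward pass of ``concatenate`` should now contain a single ``split`` rather than one ``slice`` per operand.

import jax
import jax.numpy as jnp
from jax import lax

def f(a, b):
  return lax.concatenate([a, b], dimension=0)

a = jnp.ones((2, 3))
b = jnp.ones((4, 3))
out, vjp = jax.vjp(f, a, b)
# Inspecting the backward computation should show a split primitive.
print(jax.make_jaxpr(vjp)(out))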
@@ -4499,6 +4509,76 @@ def _concatenate_lower(ctx, *xs, dimension):
 mlir.register_lowering(concatenate_p, _concatenate_lower)
 
 
+def _split_shape_rule(operand, *, sizes, axis):
+  shapes = []
+  shape = list(operand.shape)
+  if any(s < 0 for s in sizes):
+    raise ValueError(
+        f"Sizes passed to split must be nonnegative, got {list(sizes)}")
+  if operand.shape[axis] != np.sum(sizes):
+    raise ValueError(
+        f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the "
+        f"operand shape {list(operand.shape)}")
+  for size in sizes:
+    shape[axis] = size
+    shapes.append(tuple(shape))
+  return shapes
+
+def _split_dtype_rule(operand, *, sizes, axis):
+  return (operand.dtype,) * len(sizes)
+
+def _split_weak_type_rule(operand, *, sizes, axis):
+  return (operand.weak_type,) * len(sizes)
+
+def _split_transpose_rule(cotangents, operand, *, sizes, axis):
+  assert ad.is_undefined_primal(operand)
+  if all(type(t) is ad_util.Zero for t in cotangents):
+    return ad_util.Zero(operand.aval),
+  cotangents = [
+      _zeros(t.aval) if type(t) is ad_util.Zero else t
+      for t in cotangents
+  ]
+  return concatenate(cotangents, dimension=axis),
+
+def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
+  operand, = batched_args
+  bdim, = batch_dims
+  new_bdims = (bdim,) * len(sizes)
+  out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis)
+  return out, new_bdims
+
+def _split_lower(ctx, x, *, sizes, axis):
+  x_aval, = ctx.avals_in
+  start_indices = [0] * x_aval.ndim
+  limit_indices = list(x_aval.shape)
+  strides = (1,) * x_aval.ndim
+  outs = []
+  for aval_out in ctx.avals_out:
+    limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
+    out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
+                        limit_indices=limit_indices, strides=strides)
+    outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
+                if config.sharding_in_types.value else out)
+    start_indices[axis] = limit_indices[axis]
+  return outs
+
+def _split_sharding_rule(operand, *, sizes, axis):
+  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+  # change this logic to `return operand.sharding` directly.
+  out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis)
+  return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split')
+          for out_sh in out_shapes]
+
+split_p = core.Primitive('split')
+split_p.multiple_results = True
+split_p.def_abstract_eval(
+    partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule,
+            _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule))
+split_p.def_impl(partial(dispatch.apply_primitive, split_p))
+ad.deflinear2(split_p, _split_transpose_rule)
+batching.primitive_batchers[split_p] = _split_batch_rule
+mlir.register_lowering(split_p, _split_lower)
+
 def _pad_dtype_rule(operand, padding_value, *, padding_config):
   if operand.dtype != padding_value.dtype:
     msg = "pad operand and padding_value must be same dtype: got {} and {}."

jax/_src/numpy/array_methods.py

Lines changed: 2 additions & 1 deletion
@@ -629,7 +629,8 @@ def _multi_slice(self: Array,
 # avoid circular imports.
 @jax.jit
 def _unstack(x: Array) -> list[Array]:
-  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
+  dims = (0,)
+  return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
 
 def _chunk_iter(x, size):
   if size > x.shape[0]:
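
A quick sketch of the rewritten ``_unstack`` path (the helper behind iterating over an array's leading axis); the input array is illustrative:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(3, 2)
pieces = [lax.squeeze(t, (0,)) for t in lax.split(x, (1,) * x.shape[0])]
print([p.shape for p in pieces])   # expected: [(2,), (2,), (2,)]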

jax/_src/numpy/lax_numpy.py

Lines changed: 14 additions & 17 deletions
@@ -68,7 +68,7 @@
 )
 from jax._src.util import (
     NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
-    ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2,
+    ceil_of_ratio, partition_list, safe_zip, set_module, unzip2,
     tuple_replace)
 from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
                           PartitionSpec as P)
@@ -3273,10 +3273,10 @@ def _split(op: str, ary: ArrayLike,
   if (isinstance(indices_or_sections, (tuple, list)) or
       isinstance(indices_or_sections, (np.ndarray, Array)) and
       indices_or_sections.ndim > 0):
-    indices_or_sections = [
+    split_indices = np.asarray([0] + [
         core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1")
-        for i_s in indices_or_sections]
-    split_indices = [0] + list(indices_or_sections) + [size]
+        for i_s in indices_or_sections] + [size])
+    sizes = list(np.diff(split_indices))
   else:
     if core.is_symbolic_dim(indices_or_sections):
       raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is "
@@ -3285,21 +3285,14 @@
                        f"in jax.numpy.{op} argument 1")
     part_size, r = divmod(size, num_sections)
     if r == 0:
-      split_indices = [i * part_size
-                       for i in range(num_sections + 1)]
+      sizes = [part_size] * num_sections
     elif op == "array_split":
-      split_indices = (
-          [i * (part_size + 1) for i in range(r + 1)] +
-          [i * part_size + ((r + 1) * (part_size + 1) - 1)
-           for i in range(num_sections - r)])
+      sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
     else:
       raise ValueError(f"array split does not result in an equal division: rest is {r}")
-  split_indices = [i if core.is_symbolic_dim(i) else np.int64(i)  # type: ignore[misc]
-                   for i in split_indices]
-  starts, ends = [0] * ndim(ary), shape(ary)
-  _subval = lambda x, i, v: subvals(x, [(i, v)])
-  return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
-          for start, end in zip(split_indices[:-1], split_indices[1:])]
+  sizes = [i if core.is_symbolic_dim(i) else np.int64(i)  # type: ignore[misc]
+           for i in sizes]
+  return list(lax.split(ary, sizes, axis=axis))
 
 
 @export
@@ -4662,7 +4655,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]:
       "Unstack requires arrays with rank > 0, however a scalar array was "
       "passed."
     )
-  return tuple(moveaxis(x, axis, 0))
+  dimensions = (axis,)
+  return tuple(
+      lax.squeeze(t, dimensions)
+      for t in lax.split(x, (1,) * x.shape[axis], axis=axis)
+  )
 
 
 @export
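
A small sanity sketch of the rewritten helpers (input values are illustrative): ``jnp.array_split`` now builds chunk sizes directly, and ``jnp.unstack`` goes through ``lax.split`` plus ``lax.squeeze``.

import jax.numpy as jnp

# divmod(7, 3) gives part_size=2, r=1, so the chunk sizes are [3, 2, 2].
print([p.shape[0] for p in jnp.array_split(jnp.arange(7.0), 3)])   # expected: [3, 2, 2]

print([u.shape for u in jnp.unstack(jnp.ones((2, 3)), axis=1)])    # expected: [(2,), (2,), (2,)]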

jax/_src/pallas/mosaic/lowering.py

Lines changed: 21 additions & 0 deletions
@@ -1901,6 +1901,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
 lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule
 
 
+def _split_lowering_rule(
+    ctx: LoweringRuleContext, x, *, sizes, axis
+):
+  (x_aval,) = ctx.avals_in
+  slice_size = np.array(x_aval.shape, dtype=np.int64)
+  starts = np.zeros_like(slice_size)
+  strides = np.ones_like(slice_size)
+  outs = []
+  for size, aval_out in zip(sizes, ctx.avals_out):
+    slice_size[axis] = size
+    outs.append(
+        vector.extract_strided_slice(
+            aval_to_ir_type(aval_out), x, starts, slice_size, strides
+        )
+    )
+    starts[axis] += size
+  return outs
+
+lowering_rules[lax.split_p] = _split_lowering_rule
+
+
 def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
                         sharding):
   out_type = aval_to_ir_type(ctx.avals_out[0])

jax/experimental/jax2tf/jax2tf.py

Lines changed: 6 additions & 0 deletions
@@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension):
 tf_impl[lax.concatenate_p] = _concatenate
 
 
+def _split(operand, *, sizes, axis):
+  return tf.split(operand, _eval_shape(sizes), axis=axis)
+
+tf_impl[lax.split_p] = _split
+
+
 def _conv_general_dimension_numbers_proto(dimension_numbers):
   """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers."""
   assert isinstance(dimension_numbers, lax.ConvDimensionNumbers)
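
A rough usage sketch with illustrative shapes; whether the tf.split rule above is actually exercised depends on the serialization mode, so native_serialization=False is passed here (an assumption about the call site, not part of this diff) to route through the tf_impl table:

import tensorflow as tf
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

f = lambda x: lax.split(x, (1, 2), axis=0)
tf_f = jax2tf.convert(f, native_serialization=False)
print([tuple(t.shape) for t in tf_f(tf.ones((3, 4)))])   # expected: [(1, 4), (2, 4)]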

jax/experimental/jet.py

Lines changed: 4 additions & 1 deletion
@@ -73,7 +73,7 @@
 from jax._src.api_util import shaped_abstractify
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
-from jax._src.util import unzip2, weakref_lru_cache
+from jax._src.util import unzip2, weakref_lru_cache, safe_zip
 
 
 def jet(fun, primals, series):
@@ -310,6 +310,8 @@ def deflinear(prim):
 def linear_prop(prim, primals_in, series_in, **params):
   primal_out = prim.bind(*primals_in, **params)
   series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
+  if prim.multiple_results:
+    series_out = safe_zip(*series_out)
   return primal_out, series_out
 
 deflinear(lax.neg_p)
@@ -323,6 +325,7 @@ def linear_prop(prim, primals_in, series_in, **params):
 deflinear(lax.convert_element_type_p)
 deflinear(lax.broadcast_in_dim_p)
 deflinear(lax.concatenate_p)
+deflinear(lax.split_p)
 deflinear(lax.pad_p)
 deflinear(lax.reshape_p)
 deflinear(lax.squeeze_p)
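
A rough sketch of what the multiple-results handling enables, assuming the documented jet.jet calling convention of one list of Taylor terms per primal (the series values here are arbitrary): Taylor-mode differentiation can now propagate through the multi-output split primitive.

import jax.numpy as jnp
from jax import lax
from jax.experimental import jet

f = lambda x: lax.split(x, (2, 2))
x = jnp.arange(4.0)
series_in = ([jnp.ones(4), jnp.zeros(4)],)   # two Taylor terms for x
primals_out, series_out = jet.jet(f, (x,), series_in)
print([p.shape for p in primals_out])   # expected: [(2,), (2,)]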

jax/lax/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,8 @@
   sort as sort,
   sort_key_val as sort_key_val,
   sort_p as sort_p,
+  split as split,
+  split_p as split_p,
   sqrt as sqrt,
   sqrt_p as sqrt_p,
   square as square,

tests/lax_autodiff_test.py

Lines changed: 18 additions & 0 deletions
@@ -276,6 +276,24 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
     concatenate = lambda *args: lax.concatenate(args, dim)
     check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
 
+  @jtu.sample_product(
+      [dict(base_shape=base_shape, axis=axis)
+       for base_shape in [(4,), (3, 4), (2, 3, 4)]
+       for axis in range(len(base_shape))
+      ],
+      num_pieces=range(3),
+      dtype=float_dtypes,
+  )
+  def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
+    sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
+    shape = list(base_shape)
+    shape[axis] = np.sum(sizes)
+    rng = jtu.rand_default(self.rng())
+    operands = (rng(shape, dtype),)
+    split = lambda x: lax.split(x, sizes, axis)
+    check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)
+
+
   @jtu.sample_product(
       [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
        for lhs_shape, rhs_shape, all_strides in itertools.chain(
