
Commit 59e13f8

yashk2810 authored and Google-ML-Automation committed
Add sharding argument to reshape since it also takes a shape argument for the output shape
PiperOrigin-RevId: 700163883
1 parent c5dc980 commit 59e13f8

6 files changed: 62 additions and 32 deletions

jax/_src/lax/lax.py

Lines changed: 35 additions & 14 deletions
@@ -1231,7 +1231,8 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   return broadcast(x, (1,) * (rank - ndim))

 def reshape(operand: ArrayLike, new_sizes: Shape,
-            dimensions: Sequence[int] | None = None) -> Array:
+            dimensions: Sequence[int] | None = None,
+            sharding: NamedSharding | None = None) -> Array:
   """Wraps XLA's `Reshape
   <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
   operator.
@@ -1285,7 +1286,8 @@ def reshape(operand: ArrayLike, new_sizes: Shape,

   return reshape_p.bind(
       operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
-      dimensions=None if dims is None or same_dims else dims)
+      dimensions=None if dims is None or same_dims else dims,
+      sharding=sharding)

 def pad(operand: ArrayLike, padding_value: ArrayLike,
         padding_config: Sequence[tuple[int, int, int]]) -> Array:
@@ -4654,7 +4656,7 @@ def shape_as_value(shape: core.Shape):
   ]
   return concatenate(dims, dimension=0)

-def _reshape_shape_rule(operand, *, new_sizes, dimensions):
+def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding):
   if not all(d >= 0 for d in new_sizes):
     msg = 'reshape new_sizes must all be positive, got {}.'
     raise TypeError(msg.format(new_sizes))
@@ -4674,7 +4676,9 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
     raise TypeError(msg.format(dimensions, np.shape(operand)))
   return tuple(new_sizes)

-def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
+def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding):
+  if sharding is not None:
+    return sharding
   filtered_spec = [
       (sh, sp) for sh, sp in zip(operand.shape, operand.sharding.spec)
       if sh != 1
@@ -4687,14 +4691,18 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions):
     else:
       sh, sp = next(fs)
       if n != sh:
-        raise NotImplementedError
+        raise ValueError(
+            'This reshape is not supported. Please specify the sharding of the'
+            ' output via the `sharding` argument of reshape.')
       new_spec.append(sp)
   return operand.sharding.with_spec(new_spec)

-def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
+def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions,
+                            sharding):
   if not dyn_shape:
     out_aval, effects = reshape_p.abstract_eval(
-        operand.aval, new_sizes=new_sizes, dimensions=dimensions)
+        operand.aval, new_sizes=new_sizes, dimensions=dimensions,
+        sharding=sharding)
     return [out_aval], effects
   else:
     # TODO(mattjj, necula): perform more checks like _reshape_shape_rule
@@ -4705,18 +4713,29 @@ def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions):
     return [out_aval], core.no_effects


-def _reshape_dtype_rule(operand, *, new_sizes, dimensions):
+def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
   return operand.dtype

-def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions):
+def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
   assert ad.is_undefined_primal(operand)
   if dimensions is None:
+    if config.sharding_in_types.value:
+      return [reshape(t, operand.aval.shape, sharding=operand.aval.sharding)]
     return [reshape(t, operand.aval.shape)]
   else:
-    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions)),
+    if config.sharding_in_types.value:
+      t_s = operand.sharding.with_spec(
+          tuple(map(str, np.take(operand.aval.sharding.spec, dimensions))))
+    else:
+      t_s = None
+    return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
+                              sharding=t_s),
                       np.argsort(dimensions))]

-def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
+def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions,
+                        sharding):
+  if sharding is not None:
+    raise NotImplementedError
   operand, = batched_args
   bdim, = batch_dims
   operand = batching.moveaxis(operand, bdim, 0)
@@ -4725,20 +4744,22 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
   return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0


-def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
+def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
   aval_out, = ctx.avals_out
   if dimensions is not None:
     x = hlo.transpose(x, mlir.dense_int_array(dimensions))
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
   out = mlir.reshape(ctx, x, aval_out)
   if config.sharding_in_types.value:
+    if sharding is not None:
+      assert sharding == aval_out.sharding
     return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
   return [out]

 def _reshape_staging_rule(
-    trace, x, *dyn, new_sizes, dimensions):
-  params = dict(new_sizes=new_sizes, dimensions=dimensions)
+    trace, x, *dyn, new_sizes, dimensions, sharding):
+  params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding)
   if not dyn:
     return trace.default_process_primitive(reshape_p, (x,), params)
   av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type)
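
For context, a minimal usage sketch of the new argument, adapted from the test added in tests/pjit_test.py below. The mesh, axis name 'x', and shapes are illustrative; it assumes at least two devices, and the config flag name is an assumption inferred from the `config.sharding_in_types` checks in this diff.

from functools import partial

import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: the experimental sharding-in-types mode must be enabled; the
# flag name below is inferred from `config.sharding_in_types` in the diff.
jax.config.update('jax_sharding_in_types', True)

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))  # assumes >= 2 devices
x = jax.device_put(np.arange(8, dtype=np.float32).reshape(8, 1),
                   NamedSharding(mesh, P('x', None)))

# (8, 1) -> (1, 4, 2) splits the sharded axis, so the output sharding cannot
# be inferred from the operand; without `sharding=` the new rule raises
# ValueError pointing at this argument.
out_s = NamedSharding(mesh.abstract_mesh, P(None, None, 'x'))

@partial(jax.jit, static_argnums=1)
def f(x, s):
  return lax.reshape(x, (1, 4, 2), sharding=s)

y = f(x, out_s)  # output sharded as P(None, None, 'x')

Passing sharding=None keeps the previous behavior: the rule infers the output sharding where it can, and the uninferable case now raises ValueError (rather than NotImplementedError) suggesting the new argument.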

jax/_src/pallas/mosaic/lowering.py

Lines changed: 2 additions & 1 deletion
@@ -1849,7 +1849,8 @@ def _convert_element_type_lowering_rule(
 lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule


-def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
+def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions,
+                           sharding):
   if dimensions is not None:
     raise NotImplementedError
   if any(d is None for d in new_sizes):

jax/_src/pallas/triton/lowering.py

Lines changed: 1 addition & 1 deletion
@@ -1612,7 +1612,7 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions):

 @register_lowering(lax.reshape_p)
 def _reshape_lowering_rule(
-    ctx: LoweringRuleContext, a, *, new_sizes, dimensions
+    ctx: LoweringRuleContext, a, *, new_sizes, dimensions, sharding,
 ):
   del new_sizes  # Unused.
   if dimensions is not None:

jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
@@ -2291,7 +2291,7 @@ def _empty(*, dtype):
 tf_impl[lax_internal.empty_p] = _empty


-def _reshape(operand, *, new_sizes, dimensions, _in_avals, _out_aval):
+def _reshape(operand, *, new_sizes, dimensions, sharding, _in_avals, _out_aval):
   if dimensions is None:
     dimensions = tf.range(tf.rank(operand))
   new_sizes_tf = _eval_shape(new_sizes, _in_avals[0].dtype)

jax/experimental/sparse/bcoo.py

Lines changed: 3 additions & 1 deletion
@@ -1826,7 +1826,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO:
   return BCOO((new_data, new_indices), shape=out_aval.shape)


-def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[int] | None = None) -> BCOO:
+def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int],
+                 dimensions: Sequence[int] | None = None,
+                 sharding=None) -> BCOO:
   """Sparse implementation of {func}`jax.lax.reshape`.

   Args:
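
The Pallas, Triton, jax2tf, and BCOO changes above all follow one pattern: parameters bound on reshape_p are forwarded as keywords to every registered rule or impl, so each backend rule must now accept the new `sharding` parameter even if it ignores it. A hypothetical rule sketch, not from this commit, just to illustrate the signature contract:

import numpy as np

def _my_reshape_rule(operand, *, new_sizes, dimensions, sharding):
  # Hypothetical backend rule: `sharding` must appear in the keyword signature
  # because it is now a parameter of reshape_p.
  del sharding  # Accepted for compatibility; unused by this backend.
  if dimensions is not None:
    # `dimensions` is a permutation applied before reshaping.
    operand = np.transpose(operand, dimensions)
  return np.reshape(operand, new_sizes)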

tests/pjit_test.py

Lines changed: 20 additions & 14 deletions
@@ -5147,31 +5147,37 @@ def h2(x, y):
   @parameterized.named_parameters(
       ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False),
       ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True),
-      ('3', (8, 1), (1, 4, 2), P('x', None), P(None, 'x', None), True)
+      ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True)
   )
-  def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, will_error):
+  def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
+                   use_sharding_arg):
     mesh = jtu.create_mesh((2,), ('x',))
     np_inp = np.arange(math.prod(src_shape),
                        dtype=np.float32).reshape(src_shape)
     arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec))

-    @jax.jit
-    def f(x):
-      y = jnp.reshape(x, dst_shape)
+    @partial(jax.jit, static_argnums=1)
+    def f(x, new_sharding):
+      y = lax.reshape(x, dst_shape, sharding=new_sharding)
       y = y * 2
       self.assertEqual(y.sharding.spec, dst_spec)
       return y

-    if will_error:
-      with self.assertRaises(NotImplementedError):
-        f(arr)
-    else:
-      out = f(arr)
-      self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec))
-      self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
+    new_s = (NamedSharding(mesh.abstract_mesh, dst_spec)
+             if use_sharding_arg else None)
+    out = f(arr, new_s)
+    self.assertEqual(out.sharding, NamedSharding(mesh, dst_spec))
+    self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
+
+    lowered_text = f.lower(arr, new_s).as_text()
+    self.assertIn('@Sharding', lowered_text)

-    lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    def g(x):
+      out = f(x, new_s)
+      return jnp.square(jnp.sum(out))
+
+    out = jax.jit(jax.grad(g))(arr)
+    self.assertEqual(out.sharding, arr.sharding)

   def test_select(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
