     _input_dtype,
     standard_primitive,
 )
-from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.typing import Array, ArrayLike, Shape
@@ -1276,23 +1275,33 @@ def _get_sub_spec_size(mesh, sub_spec):
     return math.prod(mesh.shape[s] for s in sub_spec)
   return mesh.shape[sub_spec]
 
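# A minimal sketch (not part of the change itself) of what _get_sub_spec_size
# computes, with illustrative axis sizes: a tuple sub-spec multiplies the
# sizes of all its mesh axes, while a single axis name looks up one axis.
import math
mesh_shape = {'x': 2, 'y': 4}
assert math.prod(mesh_shape[s] for s in ('x', 'y')) == 8  # tuple sub-spec
assert mesh_shape['x'] == 2                               # single axis name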
-def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
-  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
-  # change this logic to `return operand.sharding` directly.
-  out_shape = _slice_shape_rule(operand, start_indices=start_indices,
-                                limit_indices=limit_indices, strides=strides)
+def _get_sharding_for_varying_out_shape(out_shape, operand, name):
+  """Returns a sharding when out_shape may not be the same as operand shape"""
   mesh = operand.sharding.mesh
-  new_spec = []
   for op_sh, out_sh, op_spec in safe_zip(
       operand.shape, out_shape, operand.sharding.spec):
     if (op_sh != out_sh and op_spec is not None and
         out_sh % _get_sub_spec_size(mesh, op_spec) != 0):
       raise NotImplementedError(
-          f"slicing on sharded dims where out dim ({out_sh}) is not divisible by"
+          f"{name} on sharded dims where out dim ({out_sh}) is not divisible by"
           f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec"
           f" ({op_spec}) is not implemented.")
-    new_spec.append(op_spec)
-  return NamedSharding(mesh, P(*new_spec))
+  # TODO(yashkatariya): Returning operand.sharding as is may or may not move
+  # data. So think about how to avoid it which might include creating a new
+  # mesh? For example:
+  # mesh = {'x': 4}
+  # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
+  # ys = lax.split(x, [4, 4])  # This will create outputs of shape (4,)
+  # According to the current logic, ys[0].sharding.spec == P('x')
+  # which involves data movement.
+  return operand.sharding
+
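# A runnable sketch of the example in the TODO above. Assumptions: at least
# four devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4) and
# that the experimental sharding-in-types mode routes lax.split through this
# rule; the claim that ys[0] keeps P('x') is the TODO's, not a guarantee.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:4], ('x',))
x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
ys = jax.lax.split(x, [4, 4])  # outputs of shape (4,); under the rule above
                               # each keeps spec P('x'), hence data movement.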
+def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
+  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+  # change this logic to `return operand.sharding` directly.
+  out_shape = _slice_shape_rule(operand, start_indices=start_indices,
+                                limit_indices=limit_indices, strides=strides)
+  return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing')
 
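# A sketch of the divisibility constraint the helper enforces for slicing,
# assuming a four-device 'x' mesh as above and sharding-in-types enabled:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:4], ('x',))
x = jax.device_put(jnp.arange(16), NamedSharding(mesh, P('x')))
y = jax.lax.slice(x, (0,), (8,))  # 8 % 4 == 0: operand sharding is reused
# jax.lax.slice(x, (0,), (6,))    # 6 % 4 != 0: NotImplementedError above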
 def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   assert ad.is_undefined_primal(operand)
@@ -1367,8 +1376,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
 mlir.register_lowering(slice_p, _slice_lower)
 
 
-def _dynamic_slice_shape_rule(
-    operand, *starts_and_dyn_sizes, slice_sizes):
+def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes):
   start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
   if operand.ndim != len(start_indices):
     msg = ("dynamic_slice start_indices must have length equal to the number "
@@ -1391,6 +1399,12 @@ def _dynamic_slice_shape_rule(
                     f" got indices {start_indices}")
   return tuple(lax._merge_dyn_shape(slice_sizes, dyn))
 
+def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes):
+  out_shape = _dynamic_slice_shape_rule(
+      operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes)
+  return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice')
+
+
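# The analogous sketch for dynamic_slice, same assumptions as above: the new
# rule reuses _get_sharding_for_varying_out_shape, so slice_sizes on sharded
# dims must be divisible by the corresponding mesh axis size.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:4], ('x',))
x = jax.device_put(jnp.arange(16), NamedSharding(mesh, P('x')))
y = jax.lax.dynamic_slice(x, (4,), (8,))  # out dim 8 divisible by axis size 4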
 def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes):
   start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
   if any(i.dtype != start_indices[0].dtype or
@@ -1494,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,
 
 dynamic_slice_p = standard_primitive(
     _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
-    weak_type_rule=_argnum_weak_type(0))
+    weak_type_rule=_argnum_weak_type(0),
+    sharding_rule=_dynamic_slice_sharding_rule)
 ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
 ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
 batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
@@ -1508,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
   aval_out, = ctx.avals_out
   if dyn:
     aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
-  return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)]
+  out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 
 mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
 
@@ -1539,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
                     f"scalars, got indices {start_indices}")
   return operand.shape
 
+def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
+  if operand.sharding != update.sharding:
+    raise TypeError(
+        "dynamic_update_slice update sharding must be equal to operand"
+        f" sharding, got update sharding {update.sharding} for operand sharding"
+        f" {operand.sharding}.")
+  return operand.sharding
+
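# A sketch of the equal-sharding requirement above, same mesh assumptions:
# an update that carries the operand's sharding passes; a differently
# sharded update would hit the TypeError when this rule runs.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:4], ('x',))
x = jax.device_put(jnp.arange(16), NamedSharding(mesh, P('x')))
upd = jax.device_put(jnp.zeros(8, jnp.int32), NamedSharding(mesh, P('x')))
z = jax.lax.dynamic_update_slice(x, upd, (0,))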
 def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
   lax.check_same_dtypes("dynamic_update_slice", operand, update)
   if any(i.dtype != start_indices[0].dtype or
@@ -1604,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
 
 dynamic_update_slice_p = standard_primitive(
     _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
-    'dynamic_update_slice')
+    'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule)
 ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
 ad.primitive_transposes[dynamic_update_slice_p] = \
     _dynamic_update_slice_transpose_rule
@@ -1613,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):
 
 def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
   aval_out, = ctx.avals_out
-  return [mlir.dynamic_update_slice(ctx, aval_out, x, update,
-                                    start_indices=start_indices)]
+  out = mlir.dynamic_update_slice(ctx, aval_out, x, update,
+                                  start_indices=start_indices)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 
 mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)
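# End-to-end sketch tying the rules and lowerings together. Assumption: the
# config object config.sharding_in_types is exposed as the flag
# 'jax_sharding_in_types' (name assumed, not confirmed by this diff).
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed flag name
mesh = Mesh(jax.devices()[:4], ('x',))
x = jax.device_put(jnp.arange(16), NamedSharding(mesh, P('x')))
y = jax.jit(lambda a: jax.lax.dynamic_slice(a, (0,), (8,)))(x)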
16201649