
Commit 355589f

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Add scan support to sharding_in_types. There are a few changes here:

* Set the abstract_mesh context manager during pjit_p.bind at the top level too, since scan builds a jaxpr during its lowering in `_scan_impl` (do the same for the AOT path).
* Set the abstract mesh only once, if it's not already set. Don't override an already set context. This means that only the top-level jit sets the context manager.
* Add dynamic_slice and dynamic_update_slice sharding rules, since scan calls into them.
* scan only allows `xs` whose 0th dimension is fully replicated, i.e. its spec entry is None.

PiperOrigin-RevId: 699014167
1 parent 3d79df2 commit 355589f
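For context, a rough end-to-end sketch of what the scan constraint means for users, assuming the experimental sharding-in-types mode is enabled (the flag name and mesh setup below are assumptions and may differ at this commit) and at least two devices are available (e.g. via --xla_force_host_platform_device_count=2):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumed flag spelling for the experimental mode; the mode may also require
# extra mesh axis-type setup not shown here.
jax.config.update('jax_sharding_in_types', True)

mesh = jax.make_mesh((2,), ('x',))  # hypothetical 2-device mesh

# xs: leading (scan) dim replicated (None), trailing dim sharded over 'x'.
xs = jax.device_put(jnp.arange(16.).reshape(4, 4),
                    NamedSharding(mesh, P(None, 'x')))
init = jax.device_put(jnp.zeros(4), NamedSharding(mesh, P('x')))

@jax.jit
def cumsum_rows(init, xs):
  def body(carry, x):
    carry = carry + x
    return carry, carry
  return jax.lax.scan(body, init, xs)

_, ys = cumsum_rows(init, xs)  # OK: the spec of xs starts with None

# Sharding xs as P('x', None) instead would trigger the new error:
# "0th dimension of all xs should be replicated. Got ..."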

File tree

8 files changed: +153 -62 lines changed


jax/_src/core.py

Lines changed: 6 additions & 2 deletions
@@ -2263,16 +2263,20 @@ def _map_shaped_array(
   assert axis is None or aval.shape[axis] == size
   # TODO: Extend the named shape
   if axis is None: return aval
+  sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
+              if config.sharding_in_types.value else None)
   return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
-                     weak_type=aval.weak_type)
+                     weak_type=aval.weak_type, sharding=sharding)

 def _unmap_shaped_array(
     size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
 ) -> ShapedArray:
   if axis is None: return aval
   elif type(axis) is int:
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, axis_name))
+                if config.sharding_in_types.value else None)
     return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
-                       weak_type=aval.weak_type)
+                       weak_type=aval.weak_type, sharding=sharding)
   else: raise TypeError(axis)

 def _map_dshaped_array(
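For intuition, the map/unmap change above just mirrors the existing shape surgery on the sharding spec. A rough, self-contained sketch with plain tuples (the tuple_delete/tuple_insert stand-ins below only approximate the helpers in jax._src.util):

# Stand-ins for the tuple helpers (assumed behavior).
def tuple_delete(t, idx):
  return t[:idx] + t[idx + 1:]

def tuple_insert(t, idx, val):
  return t[:idx] + (val,) + t[idx:]

shape, spec = (8, 4, 2), (None, 'x', None)

# _map_shaped_array: drop the mapped axis from both shape and spec.
mapped_shape, mapped_spec = tuple_delete(shape, 1), tuple_delete(spec, 1)
assert (mapped_shape, mapped_spec) == ((8, 2), (None, None))

# _unmap_shaped_array: reinsert the axis, naming its spec after the axis_name.
unmapped_shape = tuple_insert(mapped_shape, 1, 4)
unmapped_spec = tuple_insert(mapped_spec, 1, 'x')
assert (unmapped_shape, unmapped_spec) == ((8, 4, 2), (None, 'x', None))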

jax/_src/lax/control_flow/loops.py

Lines changed: 21 additions & 10 deletions
@@ -227,6 +227,11 @@ def scan(f, init, xs, length=None):
       msg.format(', '.join(str(x) for x in xs_flat
                            if not hasattr(x, 'shape')))) from err

+  if (config.sharding_in_types.value and
+      not all(x.sharding.spec[0] is None for x in xs_flat)):
+    raise ValueError('0th dimension of all xs should be replicated. Got '
+                     f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')
+
   if length is not None:
     try:
       length = int(length)
@@ -250,7 +255,8 @@ def scan(f, init, xs, length=None):

   if config.disable_jit.value:
     if length == 0:
-      raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.")
+      raise ValueError("zero-length scan is not supported in disable_jit() "
+                       "mode because the output type is unknown.")
     carry = init
     ys = []
     maybe_reversed = reversed if reverse else lambda x: x
@@ -424,15 +430,15 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
     num_trips, remainder = 0, length
   if unroll == 1:
     xss = xs_
-    yss = _map(partial(_empty_array, (length,)), y_avals)
+    yss = _map(partial(_empty_array, (length,), None), y_avals)
   else:
     if remainder:
       if not reverse:
         xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
       else:
         xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
     xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
-    yss = _map(partial(_empty_array, (num_trips, unroll)), y_avals)
+    yss = _map(partial(_empty_array, (num_trips, unroll), None), y_avals)

   def cond_fun(while_carry):
     i, _, _ = while_carry
@@ -477,20 +483,25 @@ def _split_leading(sz, x):

 def _concat(a, b): return lax.concatenate([a, b], 0)

-def _empty_array(prefix, aval):
-  return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape))
+def _empty_array(prefix, length_spec, aval):
+  sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
+              if config.sharding_in_types.value else None)
+  return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
+                       sharding=sharding)

 eval_jaxpr_p = core.Primitive('eval_jaxpr')
 eval_jaxpr_p.multiple_results = True
 def _stage_jaxpr(trace, *tracers, jaxpr):
   params = dict(call_jaxpr=jaxpr)
   return trace.default_process_primitive(core.closed_call_p, tracers, params)
 pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr
+
 @eval_jaxpr_p.def_effectful_abstract_eval  # abstract eval only used for jax2tf
-def _stage_jaxpr_abstract_eval(*_, jaxpr): return jaxpr.out_avals, jaxpr.effects
+def _stage_jaxpr_abstract_eval(*_, jaxpr):
+  return jaxpr.out_avals, jaxpr.effects

 def _prepend_dim_to_aval(sz, aval):
-  return core.unmapped_aval(sz, core.no_axis_name, 0, aval)
+  return core.unmapped_aval(sz, None, 0, aval)

 def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                         linear, unroll, _split_transpose):
@@ -674,7 +685,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
   extensive_res = _map(trace.new_instantiated_const, extensive_res)
   # Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
   carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
-  ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval)
+  ys_avals = [core.unmapped_aval(length, None, 0, y_aval)
               for y_aval in y_avals]
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                  for a in itertools.chain(carry_avals, ys_avals)]
@@ -1041,7 +1052,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):

   # Create residual variables.
   intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
-  ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
+  ext_avals = [core.unmapped_aval(eqn.params['length'], None, 0, a)
               for a in ext_avals_mapped]
   newvar = core.gensym()
   intensive_res = _map(newvar, intensive_avals)
@@ -1119,7 +1130,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
       jaxpr.in_avals, [num_consts, num_carry])
   carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry])
   x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
-  y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a)
+  y_avals = [core.unmapped_aval(length, None, 0, a)
              for a in y_avals_mapped]

   if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
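The new input check in scan() is a simple predicate over the flattened xs. A standalone sketch of the same logic, using SimpleNamespace stand-ins instead of real JAX arrays:

from types import SimpleNamespace

def check_xs_replicated(xs_flat):
  # Mirrors the new scan() check: every xs leaf must have spec[0] is None,
  # i.e. its leading (scan) dimension is fully replicated.
  if not all(x.sharding.spec[0] is None for x in xs_flat):
    raise ValueError(
        '0th dimension of all xs should be replicated. Got '
        f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')

ok = SimpleNamespace(sharding=SimpleNamespace(spec=(None, 'x')))
bad = SimpleNamespace(sharding=SimpleNamespace(spec=('x', None)))

check_xs_replicated([ok])         # passes
try:
  check_xs_replicated([ok, bad])  # raises: leading dim of `bad` is sharded
except ValueError as e:
  print(e)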

jax/_src/lax/lax.py

Lines changed: 2 additions & 12 deletions
@@ -4513,18 +4513,8 @@ def _pad_sharding_rule(operand, padding_value, *, padding_config):
   # change this logic to `return operand.sharding` directly.
   out_shape = _pad_shape_rule(operand, padding_value,
                               padding_config=padding_config)
-  mesh = operand.sharding.mesh
-  new_spec = []
-  for op_sh, out_sh, op_spec in safe_zip(
-      operand.shape, out_shape, operand.sharding.spec):
-    if (op_sh != out_sh and op_spec is not None and
-        out_sh % slicing._get_sub_spec_size(mesh, op_spec) != 0):
-      raise NotImplementedError(
-          f"padding on sharded dims where out dim ({out_sh}) is not divisble by"
-          f" mesh axes ({slicing._get_sub_spec_size(mesh, op_spec)}) with spec"
-          f" ({op_spec}) is not implemented.")
-    new_spec.append(op_spec)
-  return NamedSharding(mesh, P(*new_spec))
+  return slicing._get_sharding_for_varying_out_shape(
+      out_shape, operand, 'padding')


 def _pad_transpose(t, operand, padding_value, *, padding_config):
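The shared helper that pad now delegates to only rejects a dimension when it changed size, is sharded, and the new size no longer divides evenly over its mesh axes. A quick arithmetic sketch of that condition (the mesh sizes here are hypothetical):

import math

mesh_shape = {'x': 4}  # hypothetical mesh: 4 devices on axis 'x'

def sub_spec_size(spec_entry):
  # Size of the mesh axes a dimension is sharded over (cf. _get_sub_spec_size).
  if isinstance(spec_entry, tuple):
    return math.prod(mesh_shape[s] for s in spec_entry)
  return mesh_shape[spec_entry]

def dim_ok(op_dim, out_dim, spec_entry):
  # Accept if the dim is unchanged, unsharded, or still evenly divisible.
  return (op_dim == out_dim or spec_entry is None
          or out_dim % sub_spec_size(spec_entry) == 0)

assert dim_ok(8, 8, 'x')      # unchanged: always fine
assert dim_ok(8, 12, 'x')     # padded 8 -> 12, 12 % 4 == 0: fine
assert not dim_ok(8, 6, 'x')  # sliced 8 -> 6, 6 % 4 != 0: NotImplementedError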

jax/_src/lax/slicing.py

Lines changed: 46 additions & 17 deletions
@@ -42,7 +42,6 @@
     _input_dtype,
     standard_primitive,
 )
-from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.typing import Array, ArrayLike, Shape
@@ -1276,23 +1275,33 @@ def _get_sub_spec_size(mesh, sub_spec):
     return math.prod(mesh.shape[s] for s in sub_spec)
   return mesh.shape[sub_spec]

-def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
-  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
-  # change this logic to `return operand.sharding` directly.
-  out_shape = _slice_shape_rule(operand, start_indices=start_indices,
-                                limit_indices=limit_indices, strides=strides)
+def _get_sharding_for_varying_out_shape(out_shape, operand, name):
+  """Returns a sharding when out_shape may not be the same as operand shape"""
   mesh = operand.sharding.mesh
-  new_spec = []
   for op_sh, out_sh, op_spec in safe_zip(
       operand.shape, out_shape, operand.sharding.spec):
     if (op_sh != out_sh and op_spec is not None and
         out_sh % _get_sub_spec_size(mesh, op_spec) != 0):
       raise NotImplementedError(
-          f"slicing on sharded dims where out dim ({out_sh}) is not divisble by"
+          f"{name} on sharded dims where out dim ({out_sh}) is not divisble by"
           f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec"
           f" ({op_spec}) is not implemented.")
-    new_spec.append(op_spec)
-  return NamedSharding(mesh, P(*new_spec))
+  # TODO(yashkatariya): Returning operand.sharding as is may or may not move
+  # data. So think about how to avoid it which might include creating a new
+  # mesh? For example:
+  # mesh = {'x': 4}
+  # x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))`
+  # ys = lax.split(x, [4, 4])  # This will create outputs of shape (4,)
+  # According to the current logic, ys[0].sharding.spec == P('x')
+  # which involves data movement.
+  return operand.sharding
+
+def _slice_sharding_rule(operand, *, start_indices, limit_indices, strides):
+  # TODO(yashkatariya): Once JAX supports uneven sharding at the top level,
+  # change this logic to `return operand.sharding` directly.
+  out_shape = _slice_shape_rule(operand, start_indices=start_indices,
+                                limit_indices=limit_indices, strides=strides)
+  return _get_sharding_for_varying_out_shape(out_shape, operand, 'slicing')

 def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides):
   assert ad.is_undefined_primal(operand)
@@ -1367,8 +1376,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
 mlir.register_lowering(slice_p, _slice_lower)


-def _dynamic_slice_shape_rule(
-    operand, *starts_and_dyn_sizes, slice_sizes):
+def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes):
   start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
   if operand.ndim != len(start_indices):
     msg = ("dynamic_slice start_indices must have length equal to the number "
@@ -1391,6 +1399,12 @@ def _dynamic_slice_shape_rule(
            f" got indices {start_indices}")
   return tuple(lax._merge_dyn_shape(slice_sizes, dyn))

+def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes):
+  out_shape = _dynamic_slice_shape_rule(
+      operand, *starts_and_dyn_sizes, slice_sizes=slice_sizes)
+  return _get_sharding_for_varying_out_shape(out_shape, operand, 'dynamic_slice')
+
+
 def _dynamic_slice_dtype_rule(operand, *starts_and_dyn_sizes, slice_sizes):
   start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
   if any(i.dtype != start_indices[0].dtype or
@@ -1494,7 +1508,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,

 dynamic_slice_p = standard_primitive(
     _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
-    weak_type_rule=_argnum_weak_type(0))
+    weak_type_rule=_argnum_weak_type(0),
+    sharding_rule=_dynamic_slice_sharding_rule)
 ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
 ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
 batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
@@ -1508,7 +1523,10 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
   aval_out, = ctx.avals_out
   if dyn:
     aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
-  return [mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)]
+  out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]

 mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)

@@ -1539,6 +1557,14 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices):
            f"scalars, got indices {start_indices}")
   return operand.shape

+def _dynamic_update_slice_sharding_rule(operand, update, *start_indices):
+  if operand.sharding != update.sharding:
+    raise TypeError(
+        "dynamic_update_slice update sharding must be equal to operand"
+        f" sharding, got update sharding {update.sharding} for operand sharding"
+        f" {operand.sharding}.")
+  return operand.sharding
+
 def _dynamic_update_slice_dtype_rule(operand, update, *start_indices):
   lax.check_same_dtypes("dynamic_update_slice", operand, update)
   if any(i.dtype != start_indices[0].dtype or
@@ -1604,7 +1630,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):

 dynamic_update_slice_p = standard_primitive(
     _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule,
-    'dynamic_update_slice')
+    'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule)
 ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp
 ad.primitive_transposes[dynamic_update_slice_p] = \
     _dynamic_update_slice_transpose_rule
@@ -1613,8 +1639,11 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims):

 def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
   aval_out, = ctx.avals_out
-  return [mlir.dynamic_update_slice(ctx, aval_out, x, update,
-                                    start_indices=start_indices)]
+  out = mlir.dynamic_update_slice(ctx, aval_out, x, update,
+                                  start_indices=start_indices)
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]

 mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)
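At the rule level, dynamic_slice reuses the operand's sharding (subject to the same divisibility check as slicing and padding), while dynamic_update_slice insists that operand and update carry identical shardings and returns the operand's. A minimal sketch of the update-slice rule, using string placeholders rather than real jax.sharding objects:

def dynamic_update_slice_sharding(operand_sharding, update_sharding):
  # Mirrors _dynamic_update_slice_sharding_rule: the update must carry the
  # same sharding as the operand, and the output keeps the operand's sharding.
  if operand_sharding != update_sharding:
    raise TypeError(
        'dynamic_update_slice update sharding must be equal to operand '
        f'sharding, got update sharding {update_sharding} for operand '
        f'sharding {operand_sharding}.')
  return operand_sharding

# Placeholder sharding descriptions, not real NamedSharding instances.
assert dynamic_update_slice_sharding("P('x', None)", "P('x', None)") == "P('x', None)"
try:
  dynamic_update_slice_sharding("P('x', None)", "P(None, None)")
except TypeError as e:
  print(e)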

jax/_src/pjit.py

Lines changed: 24 additions & 17 deletions
@@ -185,16 +185,19 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
     args_flat = [*init_states, *args_flat]

   try:
-    if (core.trace_state_clean() and
-        not config.debug_key_reuse.value and
-        not config.data_dependent_tracing_fallback.value):
-      args_flat = map(core.full_lower, args_flat)
-      core.check_eval_args(args_flat)
-      out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
-    else:
-      out_flat = pjit_p.bind(*args_flat, **p.params)
-      compiled = None
-      profiler = None
+    # TODO(yashkatariya): Maybe thread this into pjit params like resource_env
+    # and set the context manager down the stack?
+    with p.abstract_mesh:
+      if (core.trace_state_clean() and
+          not config.debug_key_reuse.value and
+          not config.data_dependent_tracing_fallback.value):
+        args_flat = map(core.full_lower, args_flat)
+        core.check_eval_args(args_flat)
+        out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
+      else:
+        out_flat = pjit_p.bind(*args_flat, **p.params)
+        compiled = None
+        profiler = None
   except pxla.DeviceAssignmentMismatchError as e:
     fails, = e.args
     api_name = 'jit' if p.params['resource_env'] is None else 'pjit'
@@ -330,9 +333,10 @@ def cache_miss(*args, **kwargs):
     if config.no_tracing.value:
       raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
                          "`jit`, but 'no_tracing' is set")
-    outs, out_flat, out_tree, args_flat, jaxpr, \
-        attrs_tracked, executable, pgle_profiler = _python_pjit_helper(
-            fun, jit_info, *args, **kwargs)
+
+    (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,
+     pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
+
     maybe_fastpath_data = _get_fastpath_data(
         executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
         jaxpr.consts, jit_info.abstracted_axes,
@@ -495,10 +499,10 @@ def trace(*args, **kwargs) -> stages.Traced:
     donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
     args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
-                             pgle_profiler=None)
+                            pgle_profiler=None)
     return stages.Traced(
         p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
-        lower_callable, args_flat, p.arg_names, p.num_consts)
+        lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)

   wrapped = _cpp_pjit(fun, jit_info)
   wrapped.lower = lower
@@ -534,6 +538,7 @@ class PjitParams(NamedTuple):
   arg_names: tuple[str, ...] | None
   num_consts: int
   attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
+  abstract_mesh: AbstractMesh


 def _infer_params_impl(
@@ -639,7 +644,9 @@ def _infer_params_impl(

   attr_token = _attr_token(flat_fun, in_type)

-  abstract_mesh = get_abstract_mesh(in_type)
+  abstract_mesh = (
+      get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None
+      else mesh_lib.mesh_context.mesh)
   with abstract_mesh:
     jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
         flat_fun, in_type, attr_token, dbg,
@@ -684,7 +691,7 @@ def _infer_params_impl(
   )
   return PjitParams(consts, params, in_avals, in_tree, out_tree(),
                     donated_invars, dbg.arg_names if dbg else None, len(consts),
-                    attrs_tracked), args_flat
+                    attrs_tracked, abstract_mesh), args_flat


 def get_abstract_mesh(in_avals):
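The "only the top-level jit sets the context" behavior comes down to: if an abstract mesh context is already active, reuse it rather than re-deriving one from the input avals, so inner jaxpr construction (e.g. scan building its jaxpr during lowering) inherits the outer context. A simplified, self-contained sketch of that pattern (the mesh_context stand-in below is not JAX's real mesh_lib):

import contextlib
import threading

# A simplified stand-in for a thread-local "current mesh" context.
_ctx = threading.local()
_ctx.mesh = None

@contextlib.contextmanager
def use_abstract_mesh(mesh):
  prev, _ctx.mesh = getattr(_ctx, 'mesh', None), mesh
  try:
    yield
  finally:
    _ctx.mesh = prev

def resolve_abstract_mesh(mesh_from_avals):
  # Mirrors the diff: prefer an already-set context over re-deriving it,
  # so only the outermost jit establishes the abstract mesh.
  return mesh_from_avals if _ctx.mesh is None else _ctx.mesh

outer = 'mesh_from_outer_jit_avals'
inner = 'mesh_from_inner_scan_jaxpr'

with use_abstract_mesh(resolve_abstract_mesh(outer)):
  # An inner trace sees the outer context and does not override it.
  assert resolve_abstract_mesh(inner) == outer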

jax/_src/sharding_impls.py

Lines changed: 5 additions & 0 deletions
@@ -363,6 +363,11 @@ def is_fully_replicated(self) -> bool:
   def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)

+  def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
+    if not isinstance(spec, PartitionSpec):
+      spec = PartitionSpec(*spec)
+    return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
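with_spec is a small convenience for deriving a sharding with the same mesh and memory kind but a different partition spec, which the aval map/unmap and scan changes above rely on. A usage sketch (the one-device mesh below is just to make it runnable anywhere):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A trivial one-device mesh, only to construct sharding objects on any host.
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
s = NamedSharding(mesh, P('x', None))

# Same mesh and memory kind, new spec; a plain sequence is wrapped into a
# PartitionSpec automatically.
s2 = s.with_spec(P(None, 'x'))
s3 = s.with_spec((None, 'x'))

assert s2 == s3
assert s2.mesh == s.mesh and s2.memory_kind == s.memory_kind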

0 commit comments
