61 | 61 | from jax._src.lib.mlir.dialects import func as func_dialect |
62 | 62 | from jax._src.lib import jax_jit |
63 | 63 | from jax._src.lib import xla_client as xc |
64 | | -from jax._src import sharding |
65 | 64 | from jax._src.mesh import AbstractMesh |
| 65 | +from jax._src.sharding import Sharding |
66 | 66 | from jax._src.sharding_impls import ( |
67 | 67 | NamedSharding, GSPMDSharding, |
68 | 68 | SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, |
73 | 73 | from jax._src.tree_util import ( |
74 | 74 | tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, |
75 | 75 | treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, |
76 | | - PyTreeDef, none_leaf_registry as none_lr) |
| 76 | + PyTreeDef, none_leaf_registry as none_lr, tree_map) |
77 | 77 | from jax._src.util import ( |
78 | 78 | HashableFunction, safe_map, safe_zip, wraps, |
79 | 79 | distributed_debug_log, split_list, weakref_lru_cache, |
@@ -1027,7 +1027,7 @@ def hashable_pytree(pytree): |
1027 | 1027 | def _create_sharding_for_array(mesh, x, name, api_name): |
1028 | 1028 | if x is None and (mesh is None or mesh.empty): |
1029 | 1029 | return UNSPECIFIED |
1030 | | - if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)): |
| 1030 | + if isinstance(x, (AUTO, UnspecifiedValue, Sharding)): |
1031 | 1031 | return x |
1032 | 1032 | if mesh is None: |
1033 | 1033 | msg = ('jax.jit only supports `Sharding`s being passed to' |
@@ -1339,7 +1339,7 @@ def _check_and_canonicalize_out_shardings( |
1339 | 1339 | out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, |
1340 | 1340 | out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set): |
1341 | 1341 | orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) |
1342 | | - if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)): |
| 1342 | + if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)): |
1343 | 1343 | out_shardings_flat = (orig_out_shardings,) * len(out_avals) |
1344 | 1344 | else: |
1345 | 1345 | out_shardings_flat = flatten_axis_resources( |
@@ -1571,7 +1571,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] |
1571 | 1571 | else: |
1572 | 1572 | resolved_in_shardings.append(arg_s) |
1573 | 1573 | else: |
1574 | | - assert isinstance(arg_s, sharding.Sharding) |
| 1574 | + assert isinstance(arg_s, Sharding) |
1575 | 1575 | if dispatch.is_single_device_sharding(arg_s): |
1576 | 1576 | resolved_in_shardings.append(UNSPECIFIED) |
1577 | 1577 | else: |
@@ -1903,7 +1903,7 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): |
1903 | 1903 | core.custom_typechecks[pjit_p] = _pjit_typecheck |
1904 | 1904 |
1905 | 1905 |
1906 | | -def _pjit_abstract_eval(*args, jaxpr, **_): |
| 1906 | +def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): |
1907 | 1907 | return jaxpr.out_avals, jaxpr.effects |
1908 | 1908 | pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) |
1909 | 1909 |
@@ -2016,7 +2016,7 @@ def _pjit_batcher(axis_data, vals_in, dims_in, |
2016 | 2016 | batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule |
2017 | 2017 |
2018 | 2018 | def _pjit_batcher_for_sharding( |
2019 | | - s: sharding.Sharding | UnspecifiedValue, |
| 2019 | + s: Sharding | UnspecifiedValue, |
2020 | 2020 | dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): |
2021 | 2021 | if isinstance(s, UnspecifiedValue): |
2022 | 2022 | return s |
@@ -2673,6 +2673,74 @@ def _sharding_constraint_batcher( |
2673 | 2673 | batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher |
2674 | 2674 | batching.skippable_batchers[sharding_constraint_p] = lambda _: () |
2675 | 2675 |
| 2676 | +# -------------------- sharding_cast --------------------------- |
| 2677 | + |
| 2678 | +def _check_mesh_shape_same(src_sharding, dst_sharding, aval): |
| 2679 | + if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: |
| 2680 | + raise ValueError( |
| 2681 | + f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' |
| 2682 | + ' match the mesh shape of the target sharding' |
| 2683 | + f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}') |
| 2684 | + |
| 2685 | +def sharding_cast(xs, shardings): |
| 2686 | + if isinstance(shardings, NamedSharding): |
| 2687 | + return tree_map(lambda x: sharding_cast_p.bind( |
| 2688 | + x, src_sharding=x.sharding, dst_sharding=shardings), xs) |
| 2689 | + |
| 2690 | + x_flat, treedef = tree_flatten(xs) |
| 2691 | + shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings) |
| 2692 | + out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s) |
| 2693 | + for x, s in safe_zip(x_flat, shardings_flat)] |
| 2694 | + return tree_unflatten(treedef, out_flat) |
| 2695 | + |
| 2696 | +sharding_cast_p = core.Primitive('sharding_cast') |
| 2697 | +def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding): |
| 2698 | + _check_mesh_shape_same(src_sharding, dst_sharding, aval) |
| 2699 | + return aval.update(sharding=dst_sharding) |
| 2700 | +sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval) |
| 2701 | + |
| 2702 | +def _sharding_cast_impl(x, src_sharding, dst_sharding): |
| 2703 | + aval = shaped_abstractify(x) |
| 2704 | + _check_mesh_shape_same(x.sharding, dst_sharding, aval) |
| 2705 | + new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types) |
| 2706 | + concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec) |
| 2707 | + # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)` |
| 2708 | + return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x) |
| 2709 | +sharding_cast_p.def_impl(_sharding_cast_impl) |
| 2710 | + |
| 2711 | +def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding): |
| 2712 | + return [sharding_cast_p.bind(ct, src_sharding=dst_sharding, |
| 2713 | + dst_sharding=src_sharding)] |
| 2714 | +ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule) |
| 2715 | + |
| 2716 | +def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding): |
| 2717 | + aval, = ctx.avals_in |
| 2718 | + aval_out, = ctx.avals_out |
| 2719 | + proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto() |
| 2720 | + return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)] |
| 2721 | +mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering) |
| 2722 | + |
| 2723 | +# TODO(yashkatariya): Uncomment this once vmap ShiT tests are added.
| 2724 | +# def _sharding_cast_batcher(axis_data, vals_in, dims_in, src_sharding, |
| 2725 | +# dst_sharding): |
| 2726 | +# if axis_data.spmd_name is not None: |
| 2727 | +# used = {n for ns in dst_sharding.spec |
| 2728 | +# for n in (ns if isinstance(ns, tuple) else (ns,))} |
| 2729 | +# if set(axis_data.spmd_name) & used: |
| 2730 | +# raise ValueError( |
| 2731 | +# f'vmap spmd_axis_name {axis_data.spmd_name} cannot ' |
| 2732 | +# f'appear in sharding_cast spec, but got spec {dst_sharding.spec}') |
| 2733 | +# x, = vals_in |
| 2734 | +# d, = dims_in |
| 2735 | +
| 2736 | +# val = None if axis_data.spmd_name is None else axis_data.spmd_name |
| 2737 | +# new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val)) |
| 2738 | +# vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec) |
| 2739 | +# y = sharding_cast_p.bind(x, src_sharding=src_sharding, |
| 2740 | +# dst_sharding=vmapped_dst_sharding) |
| 2741 | +# return y, d |
| 2742 | +# batching.fancy_primitive_batchers[sharding_cast_p] = _sharding_cast_batcher |
| 2743 | +# batching.skippable_batchers[sharding_cast_p] = lambda _: () |
2676 | 2744 |
2677 | 2745 | # -------------------- helpers -------------------- |
2678 | 2746 |
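For orientation, here is a minimal usage sketch of the `sharding_cast` entry point added in the final hunk. This is illustrative only: the diff defines `sharding_cast` inside `jax._src.pjit` and shows no public re-export, so the import path, mesh layout, and axis names below are assumptions.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.pjit import sharding_cast  # assumed path; no public export shown

devices = np.array(jax.devices())   # 1-D mesh over all available devices
mesh = Mesh(devices, ('x',))
src = NamedSharding(mesh, P('x'))   # input sharded along 'x'
dst = NamedSharding(mesh, P(None))  # same mesh shape, fully replicated

x = jax.device_put(np.arange(2.0 * devices.size), src)
# Both shardings share one mesh, so mesh.shape_tuple matches on both sides
# and the _check_mesh_shape_same guard passes. A single NamedSharding is
# broadcast over the whole input pytree by the tree_map branch above.
y = sharding_cast(x, dst)

Because the cast is linear, the transpose rule registered with `ad.deflinear2` simply swaps `src_sharding` and `dst_sharding`, so under `jax.grad` cotangents are cast back onto the input's original sharding.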