Skip to content

Commit 40fc659

Browse files
yashk2810 and Google-ML-Automation
authored and committed
[sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted, which I will fix in follow-up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user-settable too. Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry, which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`. PiperOrigin-RevId: 698493184
1 parent 19b4996 commit 40fc659

File tree

9 files changed

+88
-19
lines changed

9 files changed

+88
-19
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ pytype_strict_library(
451451
":deprecations",
452452
":dtypes",
453453
":effects",
454+
":mesh",
454455
":pretty_printer",
455456
":source_info_util",
456457
":traceback_util",

jax/_src/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ def trace_context():
209209
Values included in this set should also most likely be included in
210210
the C++ JIT state, which is handled separately.
211211
"""
212-
return (axis_env_state.value, mesh_context_manager.value, xla_metadata_context_manager.value,
212+
return (axis_env_state.value, mesh_context_manager.value,
213+
xla_metadata_context_manager.value,
214+
abstract_mesh_context_manager.value,
213215
compute_on_context_manager.value, enable_x64.value,
214216
numpy_rank_promotion.value, default_matmul_precision.value,
215217
dynamic_shapes.value,
@@ -969,6 +971,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
969971
trace_state = config_ext.Config(None, include_in_jit_key=True)
970972
axis_env_state = config_ext.Config((), include_in_jit_key=True)
971973
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
974+
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
972975
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
973976
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
974977
else:

jax/_src/core.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from jax._src import config
3939
from jax._src import effects
4040
from jax._src import compute_on
41+
from jax._src import mesh as mesh_lib
4142
from jax._src.errors import (
4243
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
4344
TracerIntegerConversionError, UnexpectedTracerError)
@@ -1596,6 +1597,23 @@ def _invalid_shape_error(shape: Shape, context: str=""):
15961597

15971598
return TypeError(msg)
15981599

1600+
1601+
def get_sharding(sharding, ndim):
1602+
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
1603+
1604+
if sharding is not None:
1605+
assert len(sharding.spec) == ndim
1606+
return sharding
1607+
1608+
context_mesh = mesh_lib.mesh_context.mesh
1609+
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
1610+
# code.
1611+
if context_mesh is None:
1612+
return None
1613+
assert sharding is None
1614+
return NamedSharding(context_mesh, P(*[None] * ndim))
1615+
1616+
15991617
class ShapedArray(UnshapedArray):
16001618
__slots__ = ['shape', 'sharding'] # inherits slots from parent
16011619
array_abstraction_level = 2
@@ -1605,20 +1623,18 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
16051623
self.dtype = _dtype_object(dtype)
16061624
self.weak_type = weak_type
16071625
if config.sharding_in_types.value:
1608-
if sharding is not None:
1609-
assert len(sharding.spec) == len(self.shape)
1610-
self.sharding = sharding
1626+
self.sharding = get_sharding(sharding, len(self.shape))
16111627

1612-
def update(self, shape=None, dtype=None, weak_type=None, sharding=None):
1628+
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
16131629
if shape is None:
16141630
shape = self.shape
16151631
if dtype is None:
16161632
dtype = self.dtype
16171633
if weak_type is None:
16181634
weak_type = self.weak_type
1619-
if sharding is None:
1620-
sharding = getattr(self, 'sharding', None)
1621-
return ShapedArray(shape, dtype, weak_type, sharding=sharding)
1635+
if 'sharding' not in kwargs:
1636+
kwargs['sharding'] = getattr(self, 'sharding', None)
1637+
return ShapedArray(shape, dtype, weak_type, **kwargs)
16221638

16231639
ndim = property(lambda self: len(self.shape))
16241640
size = property(lambda self:

jax/_src/mesh.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class AxisTypes(enum.Enum):
107107
User = enum.auto()
108108
Collective = enum.auto()
109109

110+
def __repr__(self):
111+
return self.name
112+
110113
def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
111114
if axis_types is None:
112115
return {}
@@ -452,18 +455,34 @@ def local_mesh(self):
452455
_raise_value_error("local_mesh")
453456

454457
def __enter__(self):
455-
raise RuntimeError("AbstractMesh is not a context manager")
458+
mesh_context.stack.append(self)
459+
mesh_context.mesh = self
460+
jax_config.abstract_mesh_context_manager.set_local(
461+
tuple(m for m in mesh_context.stack if m is not None))
462+
return self
456463

457464
def __exit__(self, exc_type, exc_value, traceback):
458-
raise RuntimeError("AbstractMesh is not a context manager")
465+
mesh_context.stack.pop()
466+
mesh_context.mesh = mesh_context.stack[-1]
467+
jax_config.abstract_mesh_context_manager.set_local(
468+
tuple(m for m in mesh_context.stack if m is not None))
469+
return False
459470

460471
@staticmethod
461472
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
462-
jax_config.mesh_context_manager.set_local(mesh)
473+
jax_config.abstract_mesh_context_manager.set_local(mesh)
463474
return
464475

465476

466477
# Create this indirection because pytype fails to recognize a property if a
467478
# property raises an exception unconditionally. Remove this once that is fixed.
468479
def _raise_value_error(name):
469480
raise ValueError(f"AbstractMesh does not implement {name}")
481+
482+
483+
class MeshContext(threading.local):
484+
def __init__(self):
485+
self.stack = [None]
486+
self.mesh = self.stack[-1]
487+
488+
mesh_context = MeshContext()

jax/_src/pallas/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def get_grid_mapping(
873873
)
874874
# The inputs for the index maps
875875
index_map_avals = (
876-
(index_map_grid_aval,) * len(grid_spec.grid))
876+
(index_map_grid_aval.update(sharding=None),) * len(grid_spec.grid))
877877
index_map_tree = tree_util.tree_structure((index_map_avals, {}))
878878

879879
num_scalar_prefetch: int = getattr(grid_spec, "num_scalar_prefetch", 0)

jax/_src/pallas/mosaic/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,7 +1380,7 @@ def _masked_swap_lowering_rule(
13801380
1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
13811381
for b in ref_block_shape
13821382
]
1383-
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
1383+
mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None)
13841384
mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
13851385
_dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
13861386
if need_stride:

jax/_src/pjit.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19+
import contextlib
1920
import dataclasses
2021
from functools import partial
2122
import inspect
@@ -637,10 +638,13 @@ def _infer_params_impl(
637638
in_avals, in_tree, dbg, device_or_backend_set, have_kwargs)
638639

639640
attr_token = _attr_token(flat_fun, in_type)
640-
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
641-
flat_fun, in_type, attr_token, dbg,
642-
HashableFunction(res_paths, closure=()),
643-
IgnoreKey(ji.inline))
641+
642+
abstract_mesh = get_abstract_mesh(in_type)
643+
with abstract_mesh:
644+
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
645+
flat_fun, in_type, attr_token, dbg,
646+
HashableFunction(res_paths, closure=()),
647+
IgnoreKey(ji.inline))
644648
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)
645649

646650
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
@@ -683,6 +687,26 @@ def _infer_params_impl(
683687
attrs_tracked), args_flat
684688

685689

690+
def get_abstract_mesh(in_avals):
691+
if not config.sharding_in_types.value:
692+
return contextlib.nullcontext()
693+
m = None
694+
for a in in_avals:
695+
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
696+
if a.sharding is None: # type: ignore
697+
continue
698+
if m is not None and m != a.sharding.mesh:
699+
raise ValueError(
700+
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
701+
f' another mesh: {a.sharding.mesh}')
702+
m = a.sharding.mesh # type: ignore
703+
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
704+
if m is None:
705+
return contextlib.nullcontext()
706+
assert m is not None
707+
return m
708+
709+
686710
class InferParamsCacheEntry:
687711
"""Mutable value object for _infer_params_cached."""
688712
__slots__ = ['pjit_params']

jax/_src/state/primitives.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,10 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args,
214214
if isinstance(ref_aval.inner_aval, core.ShapedArray):
215215
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
216216
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
217-
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype)
217+
# TODO(yashkatariya): Transform the sharding too instead of setting it to
218+
# None.
219+
out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype,
220+
sharding=None)
218221
else:
219222
if transforms:
220223
raise ValueError("Cannot index non-shaped array with nontrivial indices.")

jax/experimental/shard_map.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ def _shard_map_staging(
483483
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
484484
in_avals = [t.aval for t in in_tracers]
485485
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
486-
with core.extend_axis_env_nd(list(mesh.shape.items())):
486+
with (core.extend_axis_env_nd(list(mesh.shape.items())),
487+
pjit.get_abstract_mesh(in_avals_)):
487488
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
488489
_check_names(out_names_thunk(), out_avals_)
489490
if check_rep:
@@ -547,6 +548,8 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
547548
assert isinstance(aval, core.ShapedArray)
548549
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
549550
for i, sz in enumerate(aval.shape))
551+
# TODO(yashkatariya): Reset the mesh properly based on the input avals if the
552+
# mesh of shard_map specifies collective axes.
550553
if config.sharding_in_types.value:
551554
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
552555
new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)

0 commit comments

Comments (0)