Skip to content

Commit 76dec38

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Under pjit the with mesh: context will use use_mesh(mesh): jit instead of tracking separately using resource_env.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested. This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally. PiperOrigin-RevId: 735602187
1 parent 02505fa commit 76dec38

File tree

7 files changed

+65
-108
lines changed

7 files changed

+65
-108
lines changed

jax/_src/checkify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
901901
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
902902
in_shardings, out_shardings,
903903
in_layouts, out_layouts,
904-
resource_env, donated_invars, name, inline, keep_unused,
904+
donated_invars, ctx_mesh, name, inline, keep_unused,
905905
compiler_options_kvs):
906906
# jaxpr to checked_jaxpr
907907
err_vals, err_tree = jtu.tree_flatten(error)
@@ -928,8 +928,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
928928
out_shardings=new_out_shardings,
929929
in_layouts=new_in_layouts,
930930
out_layouts=new_out_layouts,
931-
resource_env=resource_env,
932931
donated_invars=new_donated_invars,
932+
ctx_mesh=ctx_mesh,
933933
name=name,
934934
inline=inline,
935935
keep_unused=keep_unused,

jax/_src/custom_partitioning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
181181
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
182182
*tiled_args
183183
)
184-
if closed_jaxpr.out_avals != tiled_results:
184+
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
185+
[(t.shape, t.dtype) for t in tiled_results]):
185186
raise ValueError(
186187
"Mismatch in result shapes. %s vs %s"
187188
% (repr(closed_jaxpr.out_avals), repr(tiled_results))

jax/_src/interpreters/pxla.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,7 @@ def __str__(self):
16631663
elif self.name == 'OUT_SHARDING':
16641664
return 'explicit output sharding'
16651665
elif self.name == 'CONTEXT_DEVICES':
1666-
return 'devices'
1666+
return 'context mesh'
16671667
return f'{self.name}'
16681668

16691669

@@ -3060,7 +3060,6 @@ class JitGlobalCppCacheKeys:
30603060
in_layouts_leaves: tuple[Any, ...] | None = None
30613061
out_layouts_treedef: PyTreeDef | None = None
30623062
out_layouts_leaves: tuple[Any, ...] | None = None
3063-
use_resource_env: bool = False
30643063
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
30653064

30663065
@functools.cached_property

jax/_src/pjit.py

Lines changed: 56 additions & 73 deletions
Large diffs are not rendered by default.

jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3573,8 +3573,8 @@ def _pjit(*args: TfVal,
35733573
in_shardings: Sequence[sharding.Sharding],
35743574
out_shardings: Sequence[sharding.Sharding],
35753575
in_layouts, out_layouts,
3576-
resource_env: mesh.ResourceEnv,
35773576
donated_invars,
3577+
ctx_mesh,
35783578
name: str,
35793579
keep_unused: bool,
35803580
inline: bool,

jax/experimental/sparse/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n
775775

776776

777777
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
778-
in_layouts, out_layouts, resource_env, donated_invars, name,
778+
in_layouts, out_layouts, donated_invars, ctx_mesh, name,
779779
keep_unused, inline, compiler_options_kvs):
780780
if any(donated_invars):
781781
raise NotImplementedError("sparse xla_call with donated_invars")
@@ -808,8 +808,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
808808
out_shardings=out_shardings,
809809
in_layouts=in_layouts,
810810
out_layouts=out_layouts,
811-
resource_env=resource_env,
812811
donated_invars=donated_invars,
812+
ctx_mesh=ctx_mesh,
813813
name=name,
814814
keep_unused=keep_unused,
815815
inline=inline,

tests/pjit_test.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,8 +1205,7 @@ def f(x):
12051205
with self.assertRaisesRegex(
12061206
ValueError,
12071207
r"One of with_sharding_constraint.*Sharding "
1208-
r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
1209-
r"spec=PartitionSpec\(None, 'mdl', None, None\).*\) is only "
1208+
r"NamedSharding.*PartitionSpec\(None, 'mdl', None, None\).*\) is only "
12101209
"valid for values of rank at least 4, but was applied to a value of rank 1"):
12111210
pjit_f(jnp.array([1, 2, 3]))
12121211

@@ -6873,31 +6872,6 @@ def test_wsc_error(self, mesh):
68736872
' axis_types are `Auto`'):
68746873
NamedSharding(mesh, P(P.UNCONSTRAINED))
68756874

6876-
def test_use_mesh_legacy_mesh_ctx_mgr_mix_error(self):
6877-
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
6878-
6879-
with self.assertRaisesRegex(
6880-
ValueError,
6881-
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
6882-
' together is not allowed'):
6883-
with jax.sharding.use_mesh(mesh), mesh:
6884-
jax.jit(lambda x: x)(jnp.arange(8))
6885-
6886-
with self.assertRaisesRegex(
6887-
ValueError,
6888-
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
6889-
' together is not allowed'):
6890-
with jax.sharding.use_mesh(mesh), mesh:
6891-
jnp.zeros((8, 2), dtype=jnp.int32)
6892-
6893-
x = jnp.arange(8)
6894-
with self.assertRaisesRegex(
6895-
ValueError,
6896-
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
6897-
' together is not allowed'):
6898-
with jax.sharding.use_mesh(mesh), mesh:
6899-
jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
6900-
69016875
def test_pspec_einsum_no_context_mesh(self):
69026876
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
69036877
axis_types={AxisTypes.Explicit: ('x', 'y')})

0 commit comments

Comments
 (0)