Skip to content

Commit 12811f0

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users. PiperOrigin-RevId: 745067361
1 parent c4cc94a commit 12811f0

File tree

4 files changed

+4
-10
lines changed

4 files changed

+4
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2323
* Removed the `config.jax_data_dependent_tracing_fallback` config option,
2424
which was added temporarily in v0.4.36 to allow users to opt out of the
2525
new "stackless" tracing machinery.
26+
* Removed the `config.jax_eager_pmap` config option.
2627

2728
* Changes
2829
* The minimum CuDNN version is v9.8.

jax/_src/config.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,13 +1514,6 @@ def _update_disable_jit_thread_local(val):
15141514
'compute when encountering OOM errors. However, you are '
15151515
'likely to get better results manually with jax.checkpoint'))
15161516

1517-
# TODO(sharadmv,mattjj): set default to True, then remove
1518-
eager_pmap = bool_state(
1519-
name='jax_eager_pmap',
1520-
default=True,
1521-
upgrade=True,
1522-
help='Enable eager-mode pmap when jax_disable_jit is activated.')
1523-
15241517
no_tracing = bool_state(
15251518
name='jax_no_tracing',
15261519
default=False,

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ def xla_pmap_impl_lazy(
338338
donated_invars: Sequence[bool],
339339
is_explicit_global_axis_size: bool,
340340
) -> Callable:
341-
if (config.disable_jit.value and config.eager_pmap.value and
342-
not is_explicit_global_axis_size and not any(d for d in donated_invars)):
341+
if (config.disable_jit.value and
342+
not is_explicit_global_axis_size and not any(donated_invars)):
343343
def _emap_apply_fn(*args):
344344
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
345345
axis_size=axis_size, global_axis_size=global_axis_size,

tests/pmap_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3189,7 +3189,7 @@ class EagerPmapMixin:
31893189
def setUp(self):
31903190
super().setUp()
31913191
stack = contextlib.ExitStack()
3192-
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True))
3192+
stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True))
31933193
stack.enter_context(jtu.ignore_warning(
31943194
message="Some donated buffers were not usable", category=UserWarning))
31953195
self.addCleanup(stack.close)

0 commit comments

Comments
 (0)