Skip to content

Commit 51b9fe3

Browse files
hawkinsp authored and Google-ML-Automation committed
[JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS. In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack and prepares for removing the remaining uses of setUpModule as part of the work to parallelize the test suite with threads.
PiperOrigin-RevId: 713272197
1 parent f96339b commit 51b9fe3

18 files changed

+48
-172
lines changed

jax/_src/test_util.py

Lines changed: 11 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -498,29 +498,20 @@ def device_supports_buffer_donation():
498498
)
499499

500500

501-
@contextmanager
502-
def set_host_platform_device_count(nr_devices: int):
503-
"""Context manager to set host platform device count if not specified by user.
501+
def request_cpu_devices(nr_devices: int):
502+
"""Requests at least `nr_devices` CPU devices.
503+
504+
request_cpu_devices should be called at the top-level of a test module before
505+
main() runs.
504506
505-
This should only be used by tests at the top level in setUpModule(); it will
506-
not work correctly if applied to individual test cases.
507+
It is not guaranteed that the number of CPU devices will be exactly
508+
`nr_devices`: it may be more or less, depending on how exactly the test is
509+
invoked. Test cases that require a specific number of devices should skip
510+
themselves if that number is not met.
507511
"""
508-
prev_xla_flags = os.getenv("XLA_FLAGS")
509-
flags_str = prev_xla_flags or ""
510-
# Don't override user-specified device count, or other XLA flags.
511-
if "xla_force_host_platform_device_count" not in flags_str:
512-
os.environ["XLA_FLAGS"] = (flags_str +
513-
f" --xla_force_host_platform_device_count={nr_devices}")
514-
# Clear any cached backends so new CPU backend will pick up the env var.
515-
xla_bridge.get_backend.cache_clear()
516-
try:
517-
yield
518-
finally:
519-
if prev_xla_flags is None:
520-
del os.environ["XLA_FLAGS"]
521-
else:
522-
os.environ["XLA_FLAGS"] = prev_xla_flags
512+
if xla_bridge.NUM_CPU_DEVICES.value < nr_devices:
523513
xla_bridge.get_backend.cache_clear()
514+
config.update("jax_num_cpu_devices", nr_devices)
524515

525516

526517
def skip_on_flag(flag_name, skip_value):

jax/_src/xla_bridge.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,14 @@
122122
"inline without async dispatch.",
123123
)
124124

125+
NUM_CPU_DEVICES = config.int_flag(
126+
name="jax_num_cpu_devices",
127+
default=-1,
128+
help="Number of CPU devices to use. If not provided, the value of "
129+
"the XLA flag --xla_force_host_platform_device_count is used."
130+
" Must be set before JAX is initialized.",
131+
)
132+
125133

126134
# Warn the user if they call fork(), because it's not going to go well for them.
127135
def _at_fork():
@@ -249,8 +257,8 @@ def make_cpu_client(
249257
if collectives is None:
250258
collectives_impl = CPU_COLLECTIVES_IMPLEMENTATION.value
251259
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
252-
collectives_impl = 'gloo'
253-
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
260+
collectives_impl = 'gloo'
261+
warnings.warn('Setting `jax_cpu_enable_gloo_collectives` is '
254262
'deprecated. Please use `jax.config.update('
255263
'"jax_cpu_collectives_implementation", "gloo")` instead.',
256264
DeprecationWarning,
@@ -268,12 +276,22 @@ def make_cpu_client(
268276
f"{collectives_impl}. Available implementations are "
269277
f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")
270278

279+
num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
280+
if xla_client._version < 303 and num_devices is not None:
281+
xla_flags = os.getenv("XLA_FLAGS") or ""
282+
os.environ["XLA_FLAGS"] = (
283+
f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
284+
)
285+
num_devices = None
286+
# TODO(phawkins): pass num_devices directly when version 303 is the minimum.
287+
kwargs = {} if num_devices is None else {"num_devices": num_devices}
271288
return xla_client.make_cpu_client(
272289
asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
273290
distributed_client=distributed.global_state.client,
274291
node_id=distributed.global_state.process_id,
275292
num_nodes=distributed.global_state.num_processes,
276293
collectives=collectives,
294+
**kwargs,
277295
)
278296

279297

jax/experimental/array_serialization/serialization_test.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
"""Tests for serialization and deserialization of GDA."""
1515

1616
import asyncio
17-
import contextlib
1817
import math
1918
from functools import partial
2019
import os
@@ -36,13 +35,7 @@
3635
import tensorstore as ts
3736

3837
jax.config.parse_flags_with_absl()
39-
_exit_stack = contextlib.ExitStack()
40-
41-
def setUpModule():
42-
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
43-
44-
def tearDownModule():
45-
_exit_stack.close()
38+
jtu.request_cpu_devices(8)
4639

4740

4841
class CheckpointTest(jtu.JaxTestCase):

jax/experimental/jax2tf/tests/sharding_test.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
2020
"""
2121
from collections.abc import Sequence
22-
import contextlib
2322
from functools import partial
2423
import logging
2524
import re
@@ -47,16 +46,14 @@
4746
import tensorflow as tf
4847

4948
config.parse_flags_with_absl()
49+
jtu.request_cpu_devices(8)
5050

5151
# Must come after initializing the flags
5252
from jax.experimental.jax2tf.tests import tf_test_util
5353

54-
_exit_stack = contextlib.ExitStack()
5554
topology = None
5655

5756
def setUpModule():
58-
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
59-
6057
global topology
6158
if jtu.test_device_matches(["tpu"]):
6259
with jtu.ignore_warning(message="the imp module is deprecated"):
@@ -67,8 +64,6 @@ def setUpModule():
6764
else:
6865
topology = None
6966

70-
def tearDownModule():
71-
_exit_stack.close()
7267

7368

7469
class ShardingTest(tf_test_util.JaxToTfTestCase):

tests/array_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,12 @@
4343
from jax._src import prng
4444

4545
jax.config.parse_flags_with_absl()
46+
jtu.request_cpu_devices(8)
4647

4748
with contextlib.suppress(ImportError):
4849
import pytest
4950
pytestmark = pytest.mark.multiaccelerator
5051

51-
# Run all tests with 8 CPU devices.
52-
_exit_stack = contextlib.ExitStack()
53-
54-
def setUpModule():
55-
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
56-
57-
def tearDownModule():
58-
_exit_stack.close()
59-
6052

6153
def create_array(shape, sharding, global_data=None):
6254
if global_data is None:

tests/colocated_python_test.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import contextlib
1615
import threading
1716
import time
1817
from typing import Sequence
@@ -29,6 +28,7 @@
2928
import numpy as np
3029

3130
config.parse_flags_with_absl()
31+
jtu.request_cpu_devices(8)
3232

3333

3434
def _colocated_cpu_devices(
@@ -53,18 +53,6 @@ def _colocated_cpu_devices(
5353
_count_colocated_python_specialization_cache_miss = jtu.count_events(
5454
"colocated_python_func._get_specialized_func")
5555

56-
_exit_stack = contextlib.ExitStack()
57-
58-
59-
def setUpModule():
60-
# TODO(hyeontaek): Remove provisioning "cpu" backend devices once PjRt-IFRT
61-
# prepares CPU devices by its own.
62-
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
63-
64-
65-
def tearDownModule():
66-
_exit_stack.close()
67-
6856

6957
class ColocatedPythonTest(jtu.JaxTestCase):
7058

tests/debugger_test.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from collections.abc import Sequence
16-
import contextlib
1716
import io
1817
import re
1918
import textwrap
@@ -29,6 +28,7 @@
2928
import numpy as np
3029

3130
jax.config.parse_flags_with_absl()
31+
jtu.request_cpu_devices(2)
3232

3333
def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringIO]:
3434
fake_stdin = io.StringIO()
@@ -41,14 +41,6 @@ def make_fake_stdin_stdout(commands: Sequence[str]) -> tuple[IO[str], io.StringI
4141
def _format_multiline(text):
4242
return textwrap.dedent(text).lstrip()
4343

44-
_exit_stack = contextlib.ExitStack()
45-
46-
def setUpModule():
47-
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
48-
49-
def tearDownModule():
50-
_exit_stack.close()
51-
5244
foo = 2
5345

5446
class CliDebuggerTest(jtu.JaxTestCase):

tests/debugging_primitives_test.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import collections
15-
import contextlib
1615
import functools
1716
import textwrap
1817
import unittest
@@ -35,19 +34,13 @@
3534
rich = None
3635

3736
jax.config.parse_flags_with_absl()
37+
jtu.request_cpu_devices(2)
3838

3939
debug_print = debugging.debug_print
4040

4141
def _format_multiline(text):
4242
return textwrap.dedent(text).lstrip()
4343

44-
_exit_stack = contextlib.ExitStack()
45-
46-
def setUpModule():
47-
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
48-
49-
def tearDownModule():
50-
_exit_stack.close()
5144

5245
class DummyDevice:
5346
def __init__(self, platform, id):

tests/export_test.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,8 @@
5656
CAN_SERIALIZE = False
5757

5858
config.parse_flags_with_absl()
59+
jtu.request_cpu_devices(8)
5960

60-
_exit_stack = contextlib.ExitStack()
61-
62-
def setUpModule():
63-
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
64-
65-
def tearDownModule():
66-
_exit_stack.close()
6761

6862
### Setup for testing lowering with effects
6963
@dataclasses.dataclass(frozen=True)

tests/jaxpr_effects_test.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import contextlib
14+
1515
import threading
1616
import unittest
1717

@@ -34,6 +34,7 @@
3434
import numpy as np
3535

3636
config.parse_flags_with_absl()
37+
jtu.request_cpu_devices(2)
3738

3839
effect_p = core.Primitive('effect')
3940
effect_p.multiple_results = True
@@ -132,15 +133,6 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
132133
mlir.register_lowering(callback_p, callback_effect_lowering)
133134

134135

135-
_exit_stack = contextlib.ExitStack()
136-
137-
def setUpModule():
138-
_exit_stack.enter_context(jtu.set_host_platform_device_count(2))
139-
140-
def tearDownModule():
141-
_exit_stack.close()
142-
143-
144136
class JaxprEffectsTest(jtu.JaxTestCase):
145137

146138
def test_trivial_jaxpr_has_no_effects(self):

0 commit comments

Comments
 (0)