Skip to content

Commit 3fa5572

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism. If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally. PiperOrigin-RevId: 713296722
1 parent f1f98af commit 3fa5572

File tree

11 files changed

+56
-71
lines changed

11 files changed

+56
-71
lines changed

jax/_src/test_util.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,10 +1082,8 @@ class JaxTestCase(parameterized.TestCase):
10821082
'jax_legacy_prng_key': 'error',
10831083
}
10841084

1085-
_compilation_cache_exit_stack: ExitStack | None = None
1085+
_context_stack: ExitStack | None = None
10861086

1087-
def tearDown(self) -> None:
1088-
assert core.reset_trace_state()
10891087

10901088
def setUp(self):
10911089
super().setUp()
@@ -1096,11 +1094,12 @@ def setUp(self):
10961094
# b) it returns values in int32 range, which RandomState requires.
10971095
self._rng = npr.RandomState(zlib.adler32(self._testMethodName.encode()))
10981096

1099-
@classmethod
1100-
def setUpClass(cls):
1101-
cls._compilation_cache_exit_stack = ExitStack()
1102-
stack = cls._compilation_cache_exit_stack
1103-
stack.enter_context(global_config_context(**cls._default_config))
1097+
# TODO(phawkins): use TestCase.enterContext once Python 3.11 is the minimum
1098+
# version.
1099+
self._context_stack = ExitStack()
1100+
self.addCleanup(self._context_stack.close)
1101+
stack = self._context_stack
1102+
stack.enter_context(global_config_context(**self._default_config))
11041103

11051104
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
11061105
stack.enter_context(config.enable_compilation_cache(True))
@@ -1109,12 +1108,12 @@ def setUpClass(cls):
11091108
stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
11101109

11111110
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
1112-
compilation_cache.set_cache_dir(tmp_dir)
1113-
stack.callback(lambda: compilation_cache.reset_cache())
1111+
stack.enter_context(config.compilation_cache_dir(tmp_dir))
1112+
stack.callback(compilation_cache.reset_cache)
11141113

1115-
@classmethod
1116-
def tearDownClass(cls):
1117-
cls._compilation_cache_exit_stack.close()
1114+
def tearDown(self) -> None:
1115+
assert core.reset_trace_state()
1116+
super().tearDown()
11181117

11191118
def rng(self):
11201119
return self._rng

jax/experimental/jax2tf/tests/call_tf_test.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,20 @@ def _named_test(**kwargs):
6969

7070
class CallTfTest(tf_test_util.JaxToTfTestCase):
7171

72-
@classmethod
73-
def setUpClass(cls):
74-
# One TF device of each device_type
75-
cls.tf_devices = []
76-
for tf_device in tf.config.list_logical_devices():
77-
if tf_device.device_type == "TPU_SYSTEM":
78-
continue # A virtual device
79-
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
80-
cls.tf_devices.append(tf_device)
81-
82-
super().setUpClass()
83-
8472
def setUp(self):
8573
if tf is None:
8674
raise unittest.SkipTest("Test requires tensorflow")
8775
# TODO(b/171320191): this line works around a missing context initialization
8876
# bug in TensorFlow.
8977
_ = tf.add(1, 1)
9078
super().setUp()
79+
# One TF device of each device_type
80+
self.tf_devices = []
81+
for tf_device in tf.config.list_logical_devices():
82+
if tf_device.device_type == "TPU_SYSTEM":
83+
continue # A virtual device
84+
if all(tf_device.device_type != d.device_type for d in self.tf_devices):
85+
self.tf_devices.append(tf_device)
9186
self.warning_ctx = jtu.ignore_warning(
9287
message=(
9388
"(jax2tf.convert with native_serialization=False has been deprecated"
@@ -798,7 +793,7 @@ def f_jax(x):
798793

799794
jax_and_tf_platforms = (
800795
set(jax_platforms) & {d.device_type.lower()
801-
for d in self.__class__.tf_devices})
796+
for d in self.tf_devices})
802797

803798
lowering_platforms = ("tpu", "cpu", "cuda")
804799

@@ -833,7 +828,7 @@ def f_jax(x):
833828
f_jax,
834829
native_serialization=True,
835830
native_serialization_platforms=lowering_platforms))
836-
for tf_device in self.__class__.tf_devices:
831+
for tf_device in self.tf_devices:
837832
with self.subTest(tf_device.device_type):
838833
logging.info(
839834
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,17 @@
5050

5151
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
5252

53-
@classmethod
54-
def setUpClass(cls):
53+
def setUp(self):
54+
super().setUp()
5555
# One TF device of each device_type
56-
cls.tf_devices = []
56+
self.tf_devices = []
5757
for tf_device in (tf.config.list_logical_devices("TPU") +
5858
tf.config.list_logical_devices("GPU") +
5959
tf.config.list_logical_devices()):
6060
if tf_device.device_type == "TPU_SYSTEM":
6161
continue # A virtual device
62-
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
63-
cls.tf_devices.append(tf_device)
64-
65-
super().setUpClass()
66-
67-
def setUp(self):
68-
super().setUp()
62+
if all(tf_device.device_type != d.device_type for d in self.tf_devices):
63+
self.tf_devices.append(tf_device)
6964
self.warning_ctx = jtu.ignore_warning(
7065
message="jax2tf.convert with native_serialization=False has been deprecated"
7166
)
@@ -1666,7 +1661,7 @@ def f_jax(x):
16661661
f_jax,
16671662
native_serialization=True,
16681663
native_serialization_platforms=("cpu", "cuda", "tpu"))
1669-
for tf_device in self.__class__.tf_devices:
1664+
for tf_device in self.tf_devices:
16701665
logging.info(
16711666
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
16721667
with tf.device(tf_device):

jax/experimental/jax2tf/tests/sharding_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import Any
2626
import unittest
2727

28+
from absl import app
2829
from absl.testing import absltest
2930

3031
import jax
@@ -53,7 +54,8 @@
5354

5455
topology = None
5556

56-
def setUpModule():
57+
58+
def initialize_tf_tpu():
5759
global topology
5860
if jtu.test_device_matches(["tpu"]):
5961
with jtu.ignore_warning(message="the imp module is deprecated"):
@@ -64,6 +66,7 @@ def setUpModule():
6466
else:
6567
topology = None
6668

69+
app.call_after_init(initialize_tf_tpu)
6770

6871

6972
class ShardingTest(tf_test_util.JaxToTfTestCase):

tests/compilation_cache_test.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,12 @@
5252
FAKE_COMPILE_TIME = 10
5353
_counts = Counter() # Map event name to count
5454

55-
56-
def setUpModule():
57-
monitoring.register_event_listener(increment_event_count)
58-
59-
60-
def tearDownModule():
61-
monitoring._unregister_event_listener_by_callback(increment_event_count)
62-
63-
6455
def increment_event_count(event):
6556
_counts[event] += 1
6657

58+
monitoring.register_event_listener(increment_event_count)
59+
60+
6761
def msg_exists_in_logs(msg: str, records: list[logging.LogRecord],
6862
level: int | None = None) -> bool:
6963
return any(msg in record.getMessage() for record in records

tests/dynamic_api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,7 @@ def f(i):
14871487
class JumbleTest(jtu.JaxTestCase):
14881488

14891489
def setUp(self):
1490+
super().setUp()
14901491
if jax.config.x64_enabled: raise unittest.SkipTest()
14911492

14921493
@parameterized.parameters((True,), (False,))

tests/export_harnesses_multi_platform_test.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,21 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
4848

4949
class PrimitiveTest(jtu.JaxTestCase):
5050

51-
@classmethod
52-
def setUpClass(cls):
51+
def setUp(self):
52+
super().setUp()
5353
# Pick one device from each available platform
54-
cls.devices = []
55-
cls.platforms = []
54+
self.devices = []
55+
self.platforms = []
5656
for backend in ["cpu", "gpu", "tpu"]:
5757
try:
5858
devices = jax.devices(backend)
5959
except RuntimeError:
6060
devices = []
6161

6262
for d in devices:
63-
if d.platform not in cls.platforms:
64-
cls.platforms.append(d.platform)
65-
cls.devices.append(d)
66-
super().setUpClass()
63+
if d.platform not in self.platforms:
64+
self.platforms.append(d.platform)
65+
self.devices.append(d)
6766

6867
# For each primitive we export for all platforms that are available and
6968
# compare the results of running the exported code and running the native
@@ -128,7 +127,7 @@ def export_and_compare_to_native(
128127
tol: float | None = None):
129128
devices = [
130129
d
131-
for d in self.__class__.devices
130+
for d in self.devices
132131
if d.platform not in unimplemented_platforms
133132
]
134133
logging.info("Using devices %s", [str(d) for d in devices])

tests/export_test.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,16 @@ def serde_exported(*fun_args, **fun_kwargs):
159159
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
160160
class JaxExportTest(jtu.JaxTestCase):
161161

162-
@classmethod
163-
def setUpClass(cls):
162+
def setUp(self):
163+
super().setUp()
164164
# Find the available platforms
165-
cls.platforms = []
165+
self.platforms = []
166166
for backend in ["cpu", "gpu", "tpu"]:
167167
try:
168168
jax.devices(backend)
169169
except RuntimeError:
170170
continue
171-
cls.platforms.append(backend)
172-
super().setUpClass()
171+
self.platforms.append(backend)
173172

174173
def test_basic_export_only(self):
175174
@jax.jit
@@ -1499,7 +1498,7 @@ def test_multi_platform(self):
14991498
module_str)
15001499

15011500
# Call with argument placed on different plaforms
1502-
for platform in self.__class__.platforms:
1501+
for platform in self.platforms:
15031502
x_device = jax.device_put(x, jax.devices(platform)[0])
15041503
res_exp = exp.call(x_device)
15051504
self.assertAllClose(
@@ -1524,7 +1523,7 @@ def test_multi_platform_nested(self):
15241523
self.assertEqual(1, count_sine)
15251524

15261525
# Call with argument placed on different plaforms
1527-
for platform in self.__class__.platforms:
1526+
for platform in self.platforms:
15281527
if platform == "tpu": continue
15291528
x_device = jax.device_put(x, jax.devices(platform)[0])
15301529
res_exp = exp2.call(x_device)
@@ -1668,7 +1667,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4]
16681667
exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a)
16691668

16701669
# Call with argument placed on different plaforms
1671-
for platform in self.__class__.platforms:
1670+
for platform in self.platforms:
16721671
run_devices = jax.devices(platform)[0:len(export_devices)]
16731672
if len(run_devices) != len(export_devices):
16741673
continue

tests/mock_gpu_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
class MockGPUTest(jtu.JaxTestCase):
3333

3434
def setUp(self):
35+
super().setUp()
3536
if not jtu.test_device_matches(["gpu"]):
3637
self.skipTest("Mocking devices only works on the GPU backend.")
37-
super().setUp()
3838

3939
@jtu.skip_under_pytest("Test must run in an isolated process")
4040
def testMockDeviceCount(self):

tests/mock_gpu_topology_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
class MockGPUTopologyTest(jtu.JaxTestCase):
3232

3333
def setUp(self):
34+
super().setUp()
3435
if not jtu.test_device_matches(["gpu"]):
3536
self.skipTest("Mocking devices only works on the GPU backend.")
36-
super().setUp()
3737

3838
@jtu.skip_under_pytest("Test must run in an isolated process")
3939
def testMockDeviceCount(self):

0 commit comments

Comments
 (0)