Skip to content

Commit 0d047a1

Browse files
author
jax authors
committed
Merge pull request #21718 from jakevdp:pallas-config
PiperOrigin-RevId: 641349981
2 parents 44a13c9 + a2c31f4 commit 0d047a1

File tree

4 files changed

+6
-15
lines changed

4 files changed

+6
-15
lines changed

tests/mosaic/gpu_test.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747

4848

4949
# ruff: noqa: F405
50-
config.update("jax_traceback_filtering", "off")
5150
config.parse_flags_with_absl()
5251

5352
def nd_loop(bounds, body, *, _idxs = ()):
@@ -164,16 +163,9 @@ def setUp(self):
164163
self.skipTest("Only works on GPU with capability >= sm90")
165164
super().setUp()
166165
self.prng = np.random.default_rng(1234)
167-
self.ctx = mlir.make_ir_context()
168-
self.ctx.__enter__()
169-
self.loc = ir.Location.unknown()
170-
self.loc.__enter__()
171-
172-
def tearDown(self):
173-
self.loc.__exit__(None, None, None)
174-
self.ctx.__exit__(None, None, None)
175-
del self.loc, self.ctx
176-
super().tearDown()
166+
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
167+
self.enter_context(mlir.make_ir_context())
168+
self.enter_context(ir.Location.unknown())
177169

178170

179171
class TestUtilTest(TestCase):

tests/mosaic/matmul_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@
3232
from jax.experimental.mosaic.gpu.examples import matmul
3333

3434

35-
config.update("jax_traceback_filtering", "off")
3635
config.parse_flags_with_absl()
3736
os.environ["XLA_FLAGS"] = (
3837
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0")
3938

4039

40+
@jtu.with_config(jax_traceback_filtering="off")
4141
class MatmulTestCase(jtu.JaxTestCase):
4242

4343
def setUp(self):

tests/pallas/gpu_attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
# pylint: disable=no-value-for-parameter
3131

3232

33-
config.update("jax_traceback_filtering", "off")
3433
config.parse_flags_with_absl()
3534

3635

36+
@jtu.with_config(jax_traceback_filtering="off")
3737
class DecodeAttentionTest(jtu.JaxTestCase):
3838

3939
def setUp(self):

tests/pallas/pallas_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@
4949
# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
5050
# pylint: disable=no-value-for-parameter
5151

52-
53-
config.update("jax_traceback_filtering", "off")
5452
config.parse_flags_with_absl()
5553

5654
@functools.partial(jax.jit, static_argnames=["bm", "bn", "gm", "bk",
@@ -121,6 +119,7 @@ def body(i, acc_ref):
121119
return matmul_kernel(x, y)
122120

123121

122+
@jtu.with_config(jax_traceback_filtering="off")
124123
class PallasTest(jtu.JaxTestCase):
125124
INTERPRET = False
126125

0 commit comments

Comments
 (0)