Skip to content

Commit 59e480d

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Skip Mosaic GPU tests if jax_pallas_use_mosaic_gpu flag is not set.
PiperOrigin-RevId: 738906466
1 parent 5745ff5 commit 59e480d

File tree

4 files changed

+7
-3
lines changed

4 files changed

+7
-3
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ def _trace_kernel_to_jaxpr(
12061206
return jaxpr, tuple(consts)
12071207

12081208

1209-
_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
1209+
_PALLAS_USE_MOSAIC_GPU = config.bool_state(
12101210
"jax_pallas_use_mosaic_gpu",
12111211
default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False),
12121212
help=(

tests/pallas/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ jax_multiplatform_test(
215215
"gpu_h100",
216216
],
217217
env = {
218-
"JAX_PALLAS_USE_MOSAIC_GPU": "1",
219218
"JAX_PALLAS_VERBOSE_ERRORS": "0",
220219
},
221220
deps = [

tests/pallas/mosaic_gpu_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from jax import lax
2727
from jax._src import test_util as jtu
2828
from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
29+
from jax._src.pallas import pallas_call
2930
from jax.experimental import pallas as pl
3031
from jax.experimental.pallas import mosaic_gpu as plgpu
3132
import jax.numpy as jnp
@@ -59,6 +60,9 @@ class PallasTest(jtu.JaxTestCase):
5960
def setUp(self):
6061
if not jtu.is_cuda_compute_capability_at_least("9.0"):
6162
self.skipTest("Only works on a GPU with capability >= sm90")
63+
context_stack = contextlib.ExitStack()
64+
context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True))
65+
self.addCleanup(context_stack.close)
6266

6367
super().setUp()
6468

tests/pallas/ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src import linear_util as lu
3131
from jax._src import state
3232
from jax._src import test_util as jtu
33+
from jax._src.pallas import pallas_call
3334
from jax.experimental import pallas as pl
3435
from jax.interpreters import partial_eval as pe
3536
import jax.numpy as jnp
@@ -61,7 +62,7 @@
6162
jax.config.parse_flags_with_absl()
6263
jtu.setup_hypothesis(max_examples=50)
6364

64-
use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu")
65+
use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value
6566

6667
intx = dtypes.canonicalize_dtype(jnp.int64)
6768
floatx = dtypes.canonicalize_dtype(jnp.float64)

0 commit comments

Comments
 (0)