File tree Expand file tree Collapse file tree 4 files changed +7
-3
lines changed
Expand file tree Collapse file tree 4 files changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -1206,7 +1206,7 @@ def _trace_kernel_to_jaxpr(
12061206 return jaxpr , tuple (consts )
12071207
12081208
1209- _PALLAS_USE_MOSAIC_GPU = config .bool_flag (
1209+ _PALLAS_USE_MOSAIC_GPU = config .bool_state (
12101210 "jax_pallas_use_mosaic_gpu" ,
12111211 default = config .bool_env ("JAX_PALLAS_USE_MOSAIC_GPU" , False ),
12121212 help = (
Original file line number Diff line number Diff line change @@ -215,7 +215,6 @@ jax_multiplatform_test(
215215 "gpu_h100" ,
216216 ],
217217 env = {
218- "JAX_PALLAS_USE_MOSAIC_GPU" : "1" ,
219218 "JAX_PALLAS_VERBOSE_ERRORS" : "0" ,
220219 },
221220 deps = [
Original file line number Diff line number Diff line change 2626from jax import lax
2727from jax ._src import test_util as jtu
2828from jax ._src .pallas .mosaic_gpu import pipeline as mgpu_pipeline
29+ from jax ._src .pallas import pallas_call
2930from jax .experimental import pallas as pl
3031from jax .experimental .pallas import mosaic_gpu as plgpu
3132import jax .numpy as jnp
@@ -59,6 +60,9 @@ class PallasTest(jtu.JaxTestCase):
5960 def setUp (self ):
6061 if not jtu .is_cuda_compute_capability_at_least ("9.0" ):
6162 self .skipTest ("Only works on a GPU with capability >= sm90" )
63+ context_stack = contextlib .ExitStack ()
64+ context_stack .enter_context (pallas_call ._PALLAS_USE_MOSAIC_GPU (True ))
65+ self .addCleanup (context_stack .close )
6266
6367 super ().setUp ()
6468
Original file line number Diff line number Diff line change 3030from jax ._src import linear_util as lu
3131from jax ._src import state
3232from jax ._src import test_util as jtu
33+ from jax ._src .pallas import pallas_call
3334from jax .experimental import pallas as pl
3435from jax .interpreters import partial_eval as pe
3536import jax .numpy as jnp
6162jax .config .parse_flags_with_absl ()
6263jtu .setup_hypothesis (max_examples = 50 )
6364
64- use_mosaic_gpu = jax . config . read ( "jax_pallas_use_mosaic_gpu" )
65+ use_mosaic_gpu = pallas_call . _PALLAS_USE_MOSAIC_GPU . value
6566
6667intx = dtypes .canonicalize_dtype (jnp .int64 )
6768floatx = dtypes .canonicalize_dtype (jnp .float64 )
You can’t perform that action at this time.
0 commit comments