@@ -255,8 +255,8 @@ def if_building_jaxlib(
255255 "//conditions:default" : [],
256256 })
257257
258- def _get_test_deps (deps ):
259- jaxlib_build_deps = [
258+ def _get_test_deps (deps , backend_independent ):
259+ gpu_build_deps = [
260260 "//jaxlib/cuda:gpu_only_test_deps" ,
261261 "//jaxlib/rocm:gpu_only_test_deps" ,
262262 "//jax_plugins:gpu_plugin_only_test_deps" ,
@@ -273,12 +273,21 @@ def _get_test_deps(deps):
273273 "//jaxlib/tools:jaxlib_py_import" ,
274274 ]
275275
276+ if backend_independent :
277+ jaxlib_build_deps = deps
278+ gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS
279+ gpu_py_import_deps = cpu_py_imports
280+ else :
281+ jaxlib_build_deps = gpu_build_deps + deps
282+ gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS
283+ gpu_py_import_deps = gpu_py_imports
284+
276285 return select ({
277- "//jax:enable_jaxlib_build" : jaxlib_build_deps + deps ,
286+ "//jax:enable_jaxlib_build" : jaxlib_build_deps ,
278287 "//jax_plugins/cuda:disable_jaxlib_for_cpu_build" : _CPU_PYPI_WHEEL_DEPS ,
279- "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build" : _GPU_PYPI_WHEEL_DEPS ,
288+ "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build" : gpu_pypi_wheel_deps ,
280289 "//jax_plugins/cuda:enable_py_import_for_cpu_build" : cpu_py_imports ,
281- "//jax_plugins/cuda:enable_py_import_for_cuda12_build" : gpu_py_imports ,
290+ "//jax_plugins/cuda:enable_py_import_for_cuda12_build" : gpu_py_import_deps ,
282291 })
283292
284293# buildifier: disable=function-docstring
@@ -334,7 +343,7 @@ def jax_multiplatform_test(
334343 deps = _get_test_deps ([
335344 "//jax" ,
336345 "//jax:test_util" ,
337- ] + deps ),
346+ ] + deps , backend_independent = False ),
338347 data = data ,
339348 shard_count = test_shards ,
340349 tags = test_tags ,
@@ -629,15 +638,15 @@ def jax_py_test(
629638 if "PYTHONWARNINGS" not in env :
630639 env ["PYTHONWARNINGS" ] = "error"
631640 deps = kwargs .get ("deps" , [])
632- kwargs . pop ( " deps" )
633- test_deps = _get_test_deps ( deps )
634- py_test (name = name , env = env , deps = test_deps , ** kwargs )
641+ test_deps = _get_test_deps ( deps , backend_independent = True )
642+ kwargs [ "deps" ] = test_deps
643+ py_test (name = name , env = env , ** kwargs )
635644
636645def pytype_test (name , ** kwargs ):
637646 deps = kwargs .get ("deps" , [])
638- kwargs . pop ( " deps" )
639- test_deps = _get_test_deps ( deps )
640- native .py_test (name = name , deps = test_deps , ** kwargs )
647+ test_deps = _get_test_deps ( deps , backend_independent = True )
648+ kwargs [ "deps" ] = test_deps
649+ native .py_test (name = name , ** kwargs )
641650
642651def if_oss (oss_value , google_value = []):
643652 """Returns one of the arguments based on the non-configurable build env.
0 commit comments