Skip to content

Commit eb54cd2

Browse files
Remove GPU-specific dependencies from backend-independent tests.
The GPU-specific deps were added to the backend-independent tests by mistake [here](jax-ml#27113). These tests should pass using `jax` and `jaxlib` wheels only. PiperOrigin-RevId: 741663266
1 parent 6fba4ec commit eb54cd2

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

jaxlib/jax.bzl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ def if_building_jaxlib(
255255
"//conditions:default": [],
256256
})
257257

258-
def _get_test_deps(deps):
259-
jaxlib_build_deps = [
258+
def _get_test_deps(deps, backend_independent):
259+
gpu_build_deps = [
260260
"//jaxlib/cuda:gpu_only_test_deps",
261261
"//jaxlib/rocm:gpu_only_test_deps",
262262
"//jax_plugins:gpu_plugin_only_test_deps",
@@ -273,12 +273,21 @@ def _get_test_deps(deps):
273273
"//jaxlib/tools:jaxlib_py_import",
274274
]
275275

276+
if backend_independent:
277+
jaxlib_build_deps = deps
278+
gpu_pypi_wheel_deps = _CPU_PYPI_WHEEL_DEPS
279+
gpu_py_import_deps = cpu_py_imports
280+
else:
281+
jaxlib_build_deps = gpu_build_deps + deps
282+
gpu_pypi_wheel_deps = _GPU_PYPI_WHEEL_DEPS
283+
gpu_py_import_deps = gpu_py_imports
284+
276285
return select({
277-
"//jax:enable_jaxlib_build": jaxlib_build_deps + deps,
286+
"//jax:enable_jaxlib_build": jaxlib_build_deps,
278287
"//jax_plugins/cuda:disable_jaxlib_for_cpu_build": _CPU_PYPI_WHEEL_DEPS,
279-
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": _GPU_PYPI_WHEEL_DEPS,
288+
"//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": gpu_pypi_wheel_deps,
280289
"//jax_plugins/cuda:enable_py_import_for_cpu_build": cpu_py_imports,
281-
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_imports,
290+
"//jax_plugins/cuda:enable_py_import_for_cuda12_build": gpu_py_import_deps,
282291
})
283292

284293
# buildifier: disable=function-docstring
@@ -334,7 +343,7 @@ def jax_multiplatform_test(
334343
deps = _get_test_deps([
335344
"//jax",
336345
"//jax:test_util",
337-
] + deps),
346+
] + deps, backend_independent = False),
338347
data = data,
339348
shard_count = test_shards,
340349
tags = test_tags,
@@ -629,15 +638,15 @@ def jax_py_test(
629638
if "PYTHONWARNINGS" not in env:
630639
env["PYTHONWARNINGS"] = "error"
631640
deps = kwargs.get("deps", [])
632-
kwargs.pop("deps")
633-
test_deps = _get_test_deps(deps)
634-
py_test(name = name, env = env, deps = test_deps, **kwargs)
641+
test_deps = _get_test_deps(deps, backend_independent = True)
642+
kwargs["deps"] = test_deps
643+
py_test(name = name, env = env, **kwargs)
635644

636645
def pytype_test(name, **kwargs):
637646
deps = kwargs.get("deps", [])
638-
kwargs.pop("deps")
639-
test_deps = _get_test_deps(deps)
640-
native.py_test(name = name, deps = test_deps, **kwargs)
647+
test_deps = _get_test_deps(deps, backend_independent = True)
648+
kwargs["deps"] = test_deps
649+
native.py_test(name = name, **kwargs)
641650

642651
def if_oss(oss_value, google_value = []):
643652
"""Returns one of the arguments based on the non-configurable build env.

0 commit comments

Comments
 (0)