Skip to content

Commit 3f91b4b

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/ Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
1 parent 01a110c commit 3f91b4b

File tree

13 files changed

+75
-75
lines changed

13 files changed

+75
-75
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171

7272
export = set_module('jax.numpy')
7373

74-
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib']:
74+
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
7575
try:
7676
cuda_plugin_extension = importlib.import_module(
7777
f'{pkg_name}.cuda_plugin_extension'

jax_plugins/cuda/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
2626
# preinstalled jax cuda plugin packages.
27-
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
27+
for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']:
2828
try:
2929
cuda_plugin_extension = importlib.import_module(
3030
f'{pkg_name}.cuda_plugin_extension'

jax_plugins/rocm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
# rocm_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
2525
# preinstalled jax rocm plugin packages.
26-
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
26+
for pkg_name in ['jax_rocm60_plugin', 'jaxlib.cuda']:
2727
try:
2828
rocm_plugin_extension = importlib.import_module(
2929
f'{pkg_name}.rocm_plugin_extension'

jaxlib/BUILD

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -222,62 +222,3 @@ nanobind_extension(
222222
"@xla//third_party/python_runtime:headers",
223223
],
224224
)
225-
226-
cc_library(
227-
name = "gpu_plugin_extension",
228-
srcs = ["gpu_plugin_extension.cc"],
229-
hdrs = ["gpu_plugin_extension.h"],
230-
copts = [
231-
"-fexceptions",
232-
"-fno-strict-aliasing",
233-
],
234-
features = ["-use_header_modules"],
235-
deps = [
236-
":kernel_nanobind_helpers",
237-
"@com_google_absl//absl/status",
238-
"@com_google_absl//absl/status:statusor",
239-
"@com_google_absl//absl/strings:str_format",
240-
"@com_google_absl//absl/strings:string_view",
241-
"@nanobind",
242-
"@xla//xla:util",
243-
"@xla//xla/ffi/api:c_api",
244-
"@xla//xla/pjrt:status_casters",
245-
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
246-
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
247-
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
248-
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
249-
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
250-
"@xla//xla/python:py_client_gpu",
251-
"@xla//xla/tsl/python/lib/core:numpy",
252-
],
253-
)
254-
255-
nanobind_extension(
256-
name = "cuda_plugin_extension",
257-
srcs = ["cuda_plugin_extension.cc"],
258-
module_name = "cuda_plugin_extension",
259-
deps = [
260-
":gpu_plugin_extension",
261-
"@com_google_absl//absl/status",
262-
"@com_google_absl//absl/strings",
263-
"@local_config_cuda//cuda:cuda_headers",
264-
"@nanobind",
265-
"@xla//xla/pjrt:status_casters",
266-
"@xla//xla/tsl/cuda:cublas",
267-
"@xla//xla/tsl/cuda:cudart",
268-
],
269-
)
270-
271-
nanobind_extension(
272-
name = "rocm_plugin_extension",
273-
srcs = ["rocm_plugin_extension.cc"],
274-
module_name = "rocm_plugin_extension",
275-
deps = [
276-
":gpu_plugin_extension",
277-
"@com_google_absl//absl/log",
278-
"@com_google_absl//absl/strings",
279-
"@local_config_rocm//rocm:hip",
280-
"@local_config_rocm//rocm:rocm_headers",
281-
"@nanobind",
282-
],
283-
)

jaxlib/cuda/BUILD

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,29 @@ py_library(
657657
],
658658
)
659659

660+
nanobind_extension(
661+
name = "cuda_plugin_extension",
662+
srcs = ["cuda_plugin_extension.cc"],
663+
module_name = "cuda_plugin_extension",
664+
deps = [
665+
"//jaxlib/gpu:gpu_plugin_extension",
666+
"@com_google_absl//absl/status",
667+
"@com_google_absl//absl/strings",
668+
"@local_config_cuda//cuda:cuda_headers",
669+
"@nanobind",
670+
"@xla//xla/pjrt:status_casters",
671+
"@xla//xla/tsl/cuda:cublas",
672+
"@xla//xla/tsl/cuda:cudart",
673+
],
674+
)
675+
660676
# We cannot nest select and if_cuda_is_configured so we introduce
661677
# a standalone py_library target.
662678
py_library(
663679
name = "gpu_only_test_deps",
664680
# `if_cuda_is_configured` will default to `[]`.
665681
deps = if_cuda_is_configured([
666682
":cuda_gpu_support",
667-
"//jaxlib:cuda_plugin_extension",
683+
":cuda_plugin_extension",
668684
]),
669685
)

jaxlib/cuda_plugin_extension.cc renamed to jaxlib/cuda/cuda_plugin_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License.
2020
#include "absl/status/status.h"
2121
#include "absl/strings/str_cat.h"
2222
#include "third_party/gpus/cuda/include/cuda.h"
23-
#include "jaxlib/gpu_plugin_extension.h"
23+
#include "jaxlib/gpu/gpu_plugin_extension.h"
2424
#include "xla/pjrt/status_casters.h"
2525

2626
namespace nb = nanobind;

jaxlib/gpu/BUILD

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,32 @@ xla_py_proto_library(
9090
visibility = jax_visibility("triton_proto_py_users"),
9191
deps = [":triton_proto"],
9292
)
93+
94+
cc_library(
95+
name = "gpu_plugin_extension",
96+
srcs = ["gpu_plugin_extension.cc"],
97+
hdrs = ["gpu_plugin_extension.h"],
98+
copts = [
99+
"-fexceptions",
100+
"-fno-strict-aliasing",
101+
],
102+
features = ["-use_header_modules"],
103+
deps = [
104+
"//jaxlib:kernel_nanobind_helpers",
105+
"@com_google_absl//absl/status",
106+
"@com_google_absl//absl/status:statusor",
107+
"@com_google_absl//absl/strings:str_format",
108+
"@com_google_absl//absl/strings:string_view",
109+
"@nanobind",
110+
"@xla//xla:util",
111+
"@xla//xla/ffi/api:c_api",
112+
"@xla//xla/pjrt:status_casters",
113+
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
114+
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
115+
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
116+
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
117+
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
118+
"@xla//xla/python:py_client_gpu",
119+
"@xla//xla/tsl/python/lib/core:numpy",
120+
],
121+
)

jaxlib/gpu_plugin_extension.cc renamed to jaxlib/gpu/gpu_plugin_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "jaxlib/gpu_plugin_extension.h"
16+
#include "jaxlib/gpu/gpu_plugin_extension.h"
1717

1818
#include <cstddef>
1919
#include <cstdint>

jaxlib/gpu_plugin_extension.h renamed to jaxlib/gpu/gpu_plugin_extension.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
17-
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_
16+
#ifndef JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
17+
#define JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_
1818

1919
#include "nanobind/nanobind.h"
2020

@@ -24,4 +24,4 @@ void BuildGpuPluginExtension(nanobind::module_& m);
2424

2525
} // namespace xla
2626

27-
#endif // JAXLIB_GPU_PLUGIN_EXTENSION_H_
27+
#endif // JAXLIB_GPU_GPU_PLUGIN_EXTENSION_H_

jaxlib/rocm/BUILD

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,25 @@ py_library(
555555
],
556556
)
557557

558+
nanobind_extension(
559+
name = "rocm_plugin_extension",
560+
srcs = ["rocm_plugin_extension.cc"],
561+
module_name = "rocm_plugin_extension",
562+
deps = [
563+
"//jaxlib/gpu:gpu_plugin_extension",
564+
"@com_google_absl//absl/log",
565+
"@com_google_absl//absl/strings",
566+
"@local_config_rocm//rocm:hip",
567+
"@local_config_rocm//rocm:rocm_headers",
568+
"@nanobind",
569+
],
570+
)
571+
558572
py_library(
559573
name = "gpu_only_test_deps",
560574
# `if_rocm_is_configured` will default to `[]`.
561575
deps = if_rocm_is_configured([
562576
":rocm_gpu_support",
563-
"//jaxlib:rocm_plugin_extension",
577+
":rocm_plugin_extension",
564578
]),
565579
)

0 commit comments

Comments
 (0)