Commit 593143e

dfm authored and Google-ML-Automation committed
Deduplicate some GPU plugin definition code.
The `jaxlib/cuda_plugin_extension.cc` and `jaxlib/rocm_plugin_extension.cc` files were nearly identical, so this change consolidates the shared implementation into a single target.

PiperOrigin-RevId: 704785926
1 parent 210bd30 commit 593143e
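
To illustrate the consolidation pattern (this is not code from the commit): a single builder function holds the bindings that several nanobind extension modules share, and each vendor module calls it and then adds only its vendor-specific definitions. All names below (BuildCommonBindings, vendor_a_extension) are illustrative placeholders, not jaxlib symbols.

// Illustrative sketch only -- none of these names exist in jaxlib. It shows
// the structure this commit adopts: shared bindings live in one function,
// and each vendor extension module forwards to it.
#include "nanobind/nanobind.h"

namespace nb = nanobind;

namespace demo {

// Analogous in spirit to the shared gpu_plugin_extension target: adds the
// definitions that are identical across vendor modules.
void BuildCommonBindings(nb::module_& m) {
  m.def("registrations", []() { return nb::dict(); });
}

}  // namespace demo

// Each vendor module calls the shared builder, then defines only what is
// vendor-specific (e.g. device-ordinal lookup against its own runtime).
NB_MODULE(vendor_a_extension, m) {
  demo::BuildCommonBindings(m);
  m.def("vendor_name", []() { return "A"; });
}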

5 files changed: +246 -305 lines changed

jaxlib/BUILD

Lines changed: 32 additions & 20 deletions
@@ -208,27 +208,47 @@ pybind_extension(
     ],
 )
 
-pybind_extension(
-    name = "cuda_plugin_extension",
-    srcs = ["cuda_plugin_extension.cc"],
-    module_name = "cuda_plugin_extension",
+cc_library(
+    name = "gpu_plugin_extension",
+    srcs = ["gpu_plugin_extension.cc"],
+    hdrs = ["gpu_plugin_extension.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
     deps = [
+        ":kernel_nanobind_helpers",
         "@com_google_absl//absl/status",
-        "@nanobind",
-        "//jaxlib:kernel_nanobind_helpers",
-        "@xla//third_party/python_runtime:headers",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
         "@local_config_cuda//cuda:cuda_headers",
+        "@nanobind",
+        "@tsl//tsl/platform:statusor",
         "@xla//xla:util",
         "@xla//xla/ffi/api:c_api",
         "@xla//xla/pjrt:status_casters",
         "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
         "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
         "@xla//xla/pjrt/c:pjrt_c_api_helpers",
-        # TODO(jieying): move to jaxlib after py_client_gpu is separated from py_client
         "@xla//xla/python:py_client_gpu",
+        "@xla//xla/tsl/python/lib/core:numpy",
+    ],
+)
+
+pybind_extension(
+    name = "cuda_plugin_extension",
+    srcs = ["cuda_plugin_extension.cc"],
+    module_name = "cuda_plugin_extension",
+    deps = [
+        ":gpu_plugin_extension",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@nanobind",
+        "@xla//xla/pjrt:status_casters",
         "@xla//xla/tsl/cuda:cublas",
         "@xla//xla/tsl/cuda:cudart",
-        "@xla//xla/tsl/python/lib/core:numpy",
     ],
 )
 
@@ -237,20 +257,12 @@ pybind_extension(
     srcs = ["rocm_plugin_extension.cc"],
     module_name = "rocm_plugin_extension",
     deps = [
-        "//jaxlib:kernel_nanobind_helpers",
-        "@com_google_absl//absl/status",
+        ":gpu_plugin_extension",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings",
         "@local_config_rocm//rocm:hip",
        "@local_config_rocm//rocm:rocm_headers",
        "@nanobind",
-        "@xla//third_party/python_runtime:headers",
-        "@xla//xla:util",
-        "@xla//xla/ffi/api:c_api",
-        "@xla//xla/pjrt:status_casters",
-        "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
-        "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
-        "@xla//xla/pjrt/c:pjrt_c_api_helpers",
-        "@xla//xla/python:py_client_gpu",
-        "@xla//xla/tsl/python/lib/core:numpy",
     ],
 )
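
The new cc_library packages the shared implementation behind gpu_plugin_extension.h, which both vendor extensions can include. The header itself is not displayed on this page; judging from the BuildGpuPluginExtension(m) call that appears in the cuda_plugin_extension.cc diff below, its interface is presumably a single entry point along these lines (the guard name and exact signature are assumptions):

// Hypothetical sketch of jaxlib/gpu_plugin_extension.h; the real header is
// part of this commit but not shown in this view.
#ifndef JAXLIB_GPU_PLUGIN_EXTENSION_H_
#define JAXLIB_GPU_PLUGIN_EXTENSION_H_

#include "nanobind/nanobind.h"

namespace xla {

// Adds the bindings shared by the CUDA and ROCm plugin extension modules
// (custom-call target registration, the registrations dict, ...) to `m`.
void BuildGpuPluginExtension(nanobind::module_& m);

}  // namespace xla

#endif  // JAXLIB_GPU_PLUGIN_EXTENSION_H_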

jaxlib/cuda_plugin_extension.cc

Lines changed: 4 additions & 142 deletions
@@ -12,135 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <Python.h>
 
-#include <cstddef>
+#include <cstdint>
 #include <string>
-#include <string_view>
-#include <utility>
 
 #include "nanobind/nanobind.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "third_party/gpus/cuda/include/cuda.h"
-#include "jaxlib/kernel_nanobind_helpers.h"
-#include "xla/ffi/api/c_api.h"
-#include "xla/pjrt/c/pjrt_c_api.h"
-#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
-#include "xla/pjrt/c/pjrt_c_api_helpers.h"
+#include "jaxlib/gpu_plugin_extension.h"
 #include "xla/pjrt/status_casters.h"
-#include "xla/python/py_client_gpu.h"
-#include "xla/tsl/python/lib/core/numpy.h"
-#include "xla/util.h"
 
 namespace nb = nanobind;
 
 namespace xla {
 namespace {
-absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
-                                      const char* fn_name_c_str,
-                                      size_t fn_name_size, nb::object fn,
-                                      int api_version,
-                                      XLA_FFI_Handler_Traits traits) {
-  if (c_api->extension_start == nullptr) {
-    return Unimplemented("The plugin does not have extension.");
-  }
-  const PJRT_Extension_Base* next =
-      reinterpret_cast<const PJRT_Extension_Base*>(c_api->extension_start);
-  while (next != nullptr &&
-         next->type !=
-             PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
-    next = next->next;
-  }
-  if (next == nullptr) {
-    return Unimplemented("The plugin does not have a custom call extension.");
-  }
-  PJRT_Gpu_Register_Custom_Call* register_custom_call =
-      reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call;
-
-  if (traits != 0) {
-    return Unimplemented("The plugin does not support custom call traits.");
-  }
-
-  PJRT_Gpu_Register_Custom_Call_Args args;
-  args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
-  args.function_name = fn_name_c_str;
-  args.function_name_size = fn_name_size;
-
-#if PJRT_API_GPU_EXTENSION_VERSION >= 1
-  args.api_version = api_version;
-#endif
-
-  auto as_capsule = [](nb::object obj) -> absl::StatusOr<nb::capsule> {
-    nb::capsule capsule;
-    if (!nb::try_cast<nb::capsule>(obj, capsule)) {
-      return absl::InvalidArgumentError(
-          "Custom call target registration requires handlers as PyCapsules");
-    }
-    return capsule;
-  };
-
-#if PJRT_API_GPU_EXTENSION_VERSION <= 1
-  TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn));
-  args.custom_call_function = fn_execute.data();
-  RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-  return absl::OkStatus();
-#else
-  args.handler_instantiate = nullptr;
-  args.handler_prepare = nullptr;
-  args.handler_initialize = nullptr;
-  args.handler_execute = nullptr;
-
-  // Register legacy custom call target (untyped void* API).
-  if (api_version == 0) {
-    TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn));
-    args.handler_execute = capsule_execute.data();
-    RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-    return absl::OkStatus();
-  }
-
-  // Register XLA FFI handler (typed API with explicit function signatures).
-  if (api_version == 1) {
-    auto capsule_execute = as_capsule(fn);
-    if (capsule_execute.ok()) {
-      args.handler_execute = capsule_execute->data();
-      RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-      return absl::OkStatus();
-    }
-
-    nb::dict bundle;
-    if (nb::try_cast<nb::dict>(fn, bundle)) {
-      auto handler = [&](const char* name) -> absl::StatusOr<void*> {
-        if (!bundle.contains(name)) return nullptr;
-        TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name]));
-        return capsule.data();
-      };
-
-      TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate"));
-      TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare"));
-      TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize"));
-      TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute"));
-      RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-      return absl::OkStatus();
-    }
-
-    return absl::InvalidArgumentError(
-        "Unsupported custom call target type for api_version=1");
-  }
-
-  return absl::UnimplementedError(absl::StrFormat(
-      "API version %d is not supported by RegisterCustomCallTarget. "
-      "Supported versions are 0 and 1.",
-      api_version));
-#endif
-}
-
-nb::dict Registrations() {
-  nb::dict dict;
-  dict["xla_python_gpu_callback"] =
-      jax::EncapsulateFunction(xla::XlaPythonGpuCallback);
-  return dict;
-}
-
 static std::string ToString(CUresult result) {
   const char* error_name;
   if (cuGetErrorName(result, &error_name)) {
@@ -155,31 +41,7 @@ static std::string ToString(CUresult result) {
 }  // namespace
 
 NB_MODULE(cuda_plugin_extension, m) {
-  tsl::ImportNumpy();
-  m.def(
-      "register_custom_call_target",
-      [](nb::capsule c_api, nb::object fn_name_py, nb::object fn,
-         nb::str xla_platform_name, int api_version,
-         XLA_FFI_Handler_Traits traits) {
-        const char* fn_name_c_str;
-        size_t fn_name_size;
-        nb::str fn_name_bn_str;
-        if (nb::try_cast<nb::str>(fn_name_py, fn_name_bn_str)) {
-          fn_name_c_str = fn_name_bn_str.c_str();
-          fn_name_size = nb::len(fn_name_bn_str);
-        } else {
-          nb::bytes bytes = nb::cast<nb::bytes>(fn_name_py);
-          fn_name_c_str = bytes.c_str();
-          fn_name_size = bytes.size();
-        }
-        xla::ThrowIfError(RegisterCustomCallTarget(
-            static_cast<const PJRT_Api*>(c_api.data()), fn_name_c_str,
-            fn_name_size, std::move(fn), api_version, traits));
-      },
-      nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
-      nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
-      nb::arg("traits") = 0);
-  m.def("registrations", &Registrations);
+  BuildGpuPluginExtension(m);
   m.def(
       "get_device_ordinal",
       [](std::intptr_t data_value) {
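
The rocm_plugin_extension.cc diff is among the five changed files but is not shown in this view. Assuming it mirrors the CUDA module above, its body presumably collapses to a call into the shared builder plus ROCm-specific helpers; the sketch below is an assumption about its final shape, not the actual file contents.

// Hypothetical sketch of the slimmed-down jaxlib/rocm_plugin_extension.cc
// after this change; the actual diff for that file is not shown here.
#include "nanobind/nanobind.h"

#include "jaxlib/gpu_plugin_extension.h"

namespace xla {

NB_MODULE(rocm_plugin_extension, m) {
  // Vendor-agnostic bindings (register_custom_call_target, registrations)
  // now come from the shared gpu_plugin_extension library.
  BuildGpuPluginExtension(m);
  // Remaining definitions would be ROCm-specific, e.g. a HIP-based
  // get_device_ordinal analogous to the CUDA lambda above.
}

}  // namespace xla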
