@@ -12,135 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <Python.h>

-#include <cstddef>
+#include <cstdint>
 #include <string>
-#include <string_view>
-#include <utility>

 #include "nanobind/nanobind.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "third_party/gpus/cuda/include/cuda.h"
-#include "jaxlib/kernel_nanobind_helpers.h"
-#include "xla/ffi/api/c_api.h"
-#include "xla/pjrt/c/pjrt_c_api.h"
-#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
-#include "xla/pjrt/c/pjrt_c_api_helpers.h"
+#include "jaxlib/gpu_plugin_extension.h"
 #include "xla/pjrt/status_casters.h"
-#include "xla/python/py_client_gpu.h"
-#include "xla/tsl/python/lib/core/numpy.h"
-#include "xla/util.h"

 namespace nb = nanobind;

 namespace xla {
 namespace {
-absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
-                                      const char* fn_name_c_str,
-                                      size_t fn_name_size, nb::object fn,
-                                      int api_version,
-                                      XLA_FFI_Handler_Traits traits) {
-  if (c_api->extension_start == nullptr) {
-    return Unimplemented("The plugin does not have extension.");
-  }
-  const PJRT_Extension_Base* next =
-      reinterpret_cast<const PJRT_Extension_Base*>(c_api->extension_start);
-  while (next != nullptr &&
-         next->type !=
-             PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
-    next = next->next;
-  }
-  if (next == nullptr) {
-    return Unimplemented("The plugin does not have a custom call extension.");
-  }
-  PJRT_Gpu_Register_Custom_Call* register_custom_call =
-      reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call;
-
-  if (traits != 0) {
-    return Unimplemented("The plugin does not support custom call traits.");
-  }
-
-  PJRT_Gpu_Register_Custom_Call_Args args;
-  args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
-  args.function_name = fn_name_c_str;
-  args.function_name_size = fn_name_size;
-
-#if PJRT_API_GPU_EXTENSION_VERSION >= 1
-  args.api_version = api_version;
-#endif
-
-  auto as_capsule = [](nb::object obj) -> absl::StatusOr<nb::capsule> {
-    nb::capsule capsule;
-    if (!nb::try_cast<nb::capsule>(obj, capsule)) {
-      return absl::InvalidArgumentError(
-          "Custom call target registration requires handlers as PyCapsules");
-    }
-    return capsule;
-  };
-
-#if PJRT_API_GPU_EXTENSION_VERSION <= 1
-  TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn));
-  args.custom_call_function = fn_execute.data();
-  RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-  return absl::OkStatus();
-#else
-  args.handler_instantiate = nullptr;
-  args.handler_prepare = nullptr;
-  args.handler_initialize = nullptr;
-  args.handler_execute = nullptr;
-
-  // Register legacy custom call target (untyped void* API).
-  if (api_version == 0) {
-    TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn));
-    args.handler_execute = capsule_execute.data();
-    RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-    return absl::OkStatus();
-  }
-
-  // Register XLA FFI handler (typed API with explicit function signatures).
-  if (api_version == 1) {
-    auto capsule_execute = as_capsule(fn);
-    if (capsule_execute.ok()) {
-      args.handler_execute = capsule_execute->data();
-      RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-      return absl::OkStatus();
-    }
-
-    nb::dict bundle;
-    if (nb::try_cast<nb::dict>(fn, bundle)) {
-      auto handler = [&](const char* name) -> absl::StatusOr<void*> {
-        if (!bundle.contains(name)) return nullptr;
-        TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name]));
-        return capsule.data();
-      };
-
-      TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate"));
-      TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare"));
-      TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize"));
-      TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute"));
-      RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api);
-      return absl::OkStatus();
-    }
-
-    return absl::InvalidArgumentError(
-        "Unsupported custom call target type for api_version=1");
-  }
-
-  return absl::UnimplementedError(absl::StrFormat(
-      "API version %d is not supported by RegisterCustomCallTarget. "
-      "Supported versions are 0 and 1.",
-      api_version));
-#endif
-}
-
-nb::dict Registrations() {
-  nb::dict dict;
-  dict["xla_python_gpu_callback"] =
-      jax::EncapsulateFunction(xla::XlaPythonGpuCallback);
-  return dict;
-}
-
 static std::string ToString(CUresult result) {
   const char* error_name;
   if (cuGetErrorName(result, &error_name)) {
@@ -155,31 +41,7 @@ static std::string ToString(CUresult result) {
 }  // namespace

 NB_MODULE(cuda_plugin_extension, m) {
-  tsl::ImportNumpy();
-  m.def(
-      "register_custom_call_target",
-      [](nb::capsule c_api, nb::object fn_name_py, nb::object fn,
-         nb::str xla_platform_name, int api_version,
-         XLA_FFI_Handler_Traits traits) {
-        const char* fn_name_c_str;
-        size_t fn_name_size;
-        nb::str fn_name_bn_str;
-        if (nb::try_cast<nb::str>(fn_name_py, fn_name_bn_str)) {
-          fn_name_c_str = fn_name_bn_str.c_str();
-          fn_name_size = nb::len(fn_name_bn_str);
-        } else {
-          nb::bytes bytes = nb::cast<nb::bytes>(fn_name_py);
-          fn_name_c_str = bytes.c_str();
-          fn_name_size = bytes.size();
-        }
-        xla::ThrowIfError(RegisterCustomCallTarget(
-            static_cast<const PJRT_Api*>(c_api.data()), fn_name_c_str,
-            fn_name_size, std::move(fn), api_version, traits));
-      },
-      nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
-      nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
-      nb::arg("traits") = 0);
-  m.def("registrations", &Registrations);
+  BuildGpuPluginExtension(m);
   m.def(
       "get_device_ordinal",
       [](std::intptr_t data_value) {