Skip to content

Commit 28afd25

Browse files
dfm authored and Google-ML-Automation committed
Add FFI example demonstrating the use of XLA's FFI state.
Support for this was added in JAX v0.5.0. PiperOrigin-RevId: 722597704
1 parent cb188a0 commit 28afd25

File tree

4 files changed

+136
-2
lines changed

4 files changed

+136
-2
lines changed

examples/ffi/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,29 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
1313
find_package(nanobind CONFIG REQUIRED)
1414

1515
set(
16-
JAX_FFI_EXAMPLE_PROJECTS
16+
JAX_FFI_EXAMPLE_CPU_PROJECTS
1717
"rms_norm"
1818
"cpu_examples"
1919
)
2020

21-
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
21+
foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
2222
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
2323
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
2424
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
2525
endforeach()
2626

2727
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
2828
enable_language(CUDA)
29+
find_package(CUDAToolkit REQUIRED)
30+
2931
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
3032
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
3133
CUDA_STANDARD 17)
3234
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
3335
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
36+
37+
nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
38+
target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
39+
target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
40+
install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
3441
endif()
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/* Copyright 2025 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cstdint>
#include <memory>
#include <string>

#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"
22+
23+
namespace nb = nanobind;
24+
namespace ffi = xla::ffi;
25+
26+
struct State {
27+
static xla::ffi::TypeId id;
28+
explicit State(int32_t value) : value(value) {}
29+
int32_t value;
30+
};
31+
ffi::TypeId State::id = {};
32+
33+
// Instantiate-stage callback: allocates the State object with the fixed value
// 42 and returns ownership of it to the caller.
static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
  constexpr int32_t kInitialValue = 42;
  return std::make_unique<State>(kInitialValue);
}
36+
37+
static ffi::Error StateExecute(cudaStream_t stream, State* state,
38+
ffi::ResultBufferR0<ffi::S32> out) {
39+
cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
40+
cudaMemcpyHostToDevice, stream);
41+
cudaStreamSynchronize(stream);
42+
return ffi::Error::Success();
43+
}
44+
45+
// Instantiate-stage handler symbol: wraps StateInstantiate, which takes no
// bound arguments.
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
                       ffi::Ffi::BindInstantiate());

// Execute-stage handler symbol. The bindings are listed in the same order as
// the parameters of StateExecute: the CUDA stream, the State pointer, and the
// S32 scalar result buffer.
XLA_FFI_DEFINE_HANDLER(
    kStateExecute, StateExecute,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Ctx<ffi::State<State>>()
        .Ret<ffi::BufferR0<ffi::S32>>());
52+
53+
// Python extension module `_gpu_examples`.
//
// Exposes two functions:
//   type_id(): capsule wrapping the address of State::id.
//   handler(): dict with the keys "instantiate" and "execute", each mapping
//              to a capsule that wraps the corresponding FFI handler symbol.
NB_MODULE(_gpu_examples, m) {
  m.def("type_id", []() {
    void* type_id_address = reinterpret_cast<void*>(&State::id);
    return nb::capsule(type_id_address);
  });

  m.def("handler", []() {
    void* instantiate_fn = reinterpret_cast<void*>(kStateInstantiate);
    void* execute_fn = reinterpret_cast<void*>(kStateExecute);

    // Insertion order is kept the same as before: instantiate, then execute.
    nb::dict stages;
    stages["instantiate"] = nb::capsule(instantiate_fn);
    stages["execute"] = nb::capsule(execute_fn);
    return stages;
  });
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
from jax_ffi_example import _gpu_examples
17+
import jax.numpy as jnp
18+
19+
# Register the "state" FFI target for the CUDA platform. The handler dict
# comes from the compiled extension module.
jax.ffi.register_ffi_target(
    "state",
    _gpu_examples.handler(),
    platform="CUDA",
)
# Register the type id exported by the extension under the same "state" name.
jax.ffi.register_ffi_type_id(
    "state",
    _gpu_examples.type_id(),
    platform="CUDA",
)
21+
22+
23+
def read_state():
  """Calls the "state" FFI target and returns its scalar int32 result."""
  # The target takes no operands and yields a single shape-() int32 output.
  output_spec = jax.ShapeDtypeStruct((), jnp.int32)
  state_call = jax.ffi.ffi_call("state", output_spec)
  return state_call()
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
import jax
17+
from jax._src import test_util as jtu
18+
19+
jax.config.parse_flags_with_absl()
20+
21+
22+
class GpuExamplesTest(jtu.JaxTestCase):
  """Tests for the stateful GPU FFI example; skipped unless running on CUDA."""

  def setUp(self):
    super().setUp()
    running_on_cuda = jtu.test_device_matches(["cuda"])
    if not running_on_cuda:
      self.skipTest("Unsupported platform")

    # Import here to avoid trying to load the library when it's not built.
    from jax_ffi_example import gpu_examples  # pylint: disable=g-import-not-at-top

    self.read_state = gpu_examples.read_state

  def test_basic(self):
    # Call the function directly first, then through jax.jit.
    direct_result = self.read_state()
    self.assertEqual(direct_result, 42)

    jitted_read_state = jax.jit(self.read_state)
    self.assertEqual(jitted_read_state(), 42)
38+
39+
40+
if __name__ == "__main__":
  # Run the tests through absltest using JAX's test loader.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)

0 commit comments

Comments
 (0)