Skip to content

Commit 8370082

Browse files
Merge pull request jax-ml#23805 from dfm:ffi-examples-state
PiperOrigin-RevId: 696383873
2 parents 268c86f + f086483 commit 8370082

File tree

5 files changed

+179
-9
lines changed

5 files changed

+179
-9
lines changed

examples/ffi/CMakeLists.txt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
1212

1313
find_package(nanobind CONFIG REQUIRED)
1414

15-
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
16-
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
17-
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
15+
# Names of the example projects; each one has a matching
# "src/jax_ffi_example/<name>.cc" source file.
set(JAX_FFI_EXAMPLE_PROJECTS "rms_norm" "attrs" "counter")
1821

19-
nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
20-
target_include_directories(_attrs PUBLIC ${XLA_DIR})
21-
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
22+
# For every listed example, build a static-nanobind extension module named
# "_<example>", expose the XLA headers to it, and install it into the package.
foreach(EXAMPLE IN LISTS JAX_FFI_EXAMPLE_PROJECTS)
  nanobind_add_module("_${EXAMPLE}" NB_STATIC "src/jax_ffi_example/${EXAMPLE}.cc")
  target_include_directories("_${EXAMPLE}" PUBLIC ${XLA_DIR})
  install(TARGETS "_${EXAMPLE}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endforeach()
2227

2328
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
2429
enable_language(CUDA)

examples/ffi/README.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,26 @@
33
This directory includes an example project demonstrating the use of JAX's
44
foreign function interface (FFI). The JAX docs provide more information about
55
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
6-
but the example in this directory explicitly demonstrates:
6+
but the example in this directory complements that document by demonstrating
7+
(and testing!) the full packaging workflow, and some more advanced use cases.
8+
Within the example project, there are several example calls:
79

8-
1. One way to package and distribute FFI targets, and
9-
2. Some more advanced use cases.
10+
1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it
11+
demonstrates the most basic use of the FFI. It also includes customization of
12+
behavior under automatic differentiation using `jax.custom_vjp`.
13+
14+
2. `counter`: This example demonstrates a common pattern for how an FFI call can
15+
use a global cache to maintain state between calls. This pattern is useful when
16+
an FFI call requires an expensive initialization step which shouldn't be
17+
run on every execution, or if there is other shared state that could be
18+
reused between calls. In this simple example we just count the number of
19+
times the call was executed.
20+
21+
3. `attrs`: An example demonstrating the different ways that attributes can be
22+
passed to the FFI. For example, we can pass arrays, variadic attributes, and
23+
user-defined types. User-defined types aren't yet fully supported
24+
by XLA, so that example will be added in the future.
25+
26+
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
27+
CUDA. The specifics of the kernels are not very important, but the general
28+
structure, and packaging of the extension are useful for testing.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright 2024 The JAX Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cstdint>
17+
#include <mutex>
18+
#include <string_view>
19+
#include <unordered_map>
20+
21+
#include "nanobind/nanobind.h"
22+
#include "xla/ffi/api/ffi.h"
23+
24+
namespace nb = nanobind;
25+
namespace ffi = xla::ffi;
26+
27+
// Implementation of the "counter" FFI call.
//
// Writes into `out` the number of earlier invocations that used the same
// `index`: the first call for a given index yields 0, the next 1, and so on.
// Counts live in a process-wide table, so they persist between calls.
// Always returns a success status.
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
  // Serializes all access to the shared table below.
  static std::mutex mutex;
  // Allocated on first use and never deleted, so the reference remains valid
  // for every later call.
  static auto& cache = *new std::unordered_map<int64_t, int32_t>();
  {
    const std::lock_guard<std::mutex> lock(mutex);
    // One lookup: inserts {index, 0} when the index is new, otherwise hands
    // back the existing entry untouched.
    const auto [entry, is_new] = cache.try_emplace(index, 0);
    if (is_new) {
      out->typed_data()[0] = 0;
    } else {
      out->typed_data()[0] = ++entry->second;
    }
  }
  return ffi::Error::Success();
}
42+
43+
// Defines the handler symbol `Counter`, which dispatches to CounterImpl with
// one int64_t attribute named "index" and one ffi::BufferR0<ffi::S32> result.
XLA_FFI_DEFINE_HANDLER_SYMBOL(Counter, CounterImpl,
                              ffi::Ffi::Bind()
                                  .Attr<int64_t>("index")
                                  .Ret<ffi::BufferR0<ffi::S32>>());
46+
47+
// Python extension module `_counter`.
//
// Exposes a single function, `registrations()`, that returns a dict mapping
// the target name "counter" to a capsule holding the `Counter` handler pointer.
NB_MODULE(_counter, m) {
  m.def("registrations", []() {
    nb::dict targets;
    // The handler is passed to Python as an opaque pointer inside a capsule.
    targets["counter"] = nb::capsule(reinterpret_cast<void*>(Counter));
    return targets;
  });
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""An example demonstrating how an FFI call can maintain "state" between calls
16+
17+
In this case, the ``counter`` call simply accumulates the number of times it
18+
was executed, but this pattern can also be used for more advanced use cases.
19+
For example, this pattern is used in jaxlib for:
20+
21+
1. The GPU solver linear algebra kernels which require an expensive "handler"
22+
initialization, and
23+
2. The ``triton_call`` function which caches the compiled triton modules after
24+
their first use.
25+
"""
26+
27+
import jax
28+
import jax.extend as jex
29+
30+
from jax_ffi_example import _counter
31+
32+
# Register every FFI target provided by the compiled extension with JAX.
for target_name, capsule in _counter.registrations().items():
  jex.ffi.register_ffi_target(target_name, capsule)
34+
35+
36+
def counter(index):
  """Invokes the ``counter`` FFI target for the given ``index``.

  Args:
    index: Value converted with ``int()`` and passed to the FFI call as the
      ``index`` attribute.

  Returns:
    The result of the FFI call, declared as a scalar (shape ``()``) ``int32``.
  """
  result_spec = jax.ShapeDtypeStruct((), jax.numpy.int32)
  call = jex.ffi.ffi_call("counter", result_spec)
  return call(index=int(index))

examples/ffi/tests/counter_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
17+
import jax
18+
from jax._src import test_util as jtu
19+
20+
from jax_ffi_example import counter
21+
22+
jax.config.parse_flags_with_absl()
23+
24+
25+
class CounterTests(jtu.JaxTestCase):
  # Checks that ``counter.counter`` keeps a separate running count per index.

  def setUp(self):
    super().setUp()
    # Only run when the test device is CPU; skip everywhere else.
    on_cpu = jtu.test_device_matches(["cpu"])
    if not on_cpu:
      self.skipTest("Unsupported platform")

  def test_basic(self):
    # (index, expected result) pairs, evaluated strictly in this order.
    calls = [(0, 0), (0, 1), (0, 2), (1, 0), (0, 3)]
    for index, expected in calls:
      self.assertEqual(counter.counter(index), expected)

  def test_jit(self):
    @jax.jit
    def counter_fun(x):
      return x, counter.counter(2)

    _, count = counter_fun(0)
    self.assertEqual(count, 0)
    _, count = counter_fun(0)
    self.assertEqual(count, 1)

    # Persists across different cache hits
    _, count = counter_fun(1)
    self.assertEqual(count, 2)

    # Persists after the cache is cleared
    counter_fun.clear_cache()
    _, count = counter_fun(0)
    self.assertEqual(count, 3)
52+
53+
54+
if __name__ == "__main__":
  # Run the tests with the loader supplied by jax's test utilities.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)

0 commit comments

Comments
 (0)