Skip to content

Commit 788f493

Browse files
Merge pull request jax-ml#25041 from dfm:ffi-example-refactor
PiperOrigin-RevId: 700093685
2 parents 6761512 + 84a9cba commit 788f493

File tree

11 files changed

+122
-189
lines changed

11 files changed

+122
-189
lines changed

examples/ffi/CMakeLists.txt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED)
1515
set(
1616
JAX_FFI_EXAMPLE_PROJECTS
1717
"rms_norm"
18-
"attrs"
19-
"counter"
18+
"cpu_examples"
2019
)
2120

2221
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
@@ -27,9 +26,9 @@ endforeach()
2726

2827
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
2928
enable_language(CUDA)
30-
add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
31-
set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
32-
CUDA_STANDARD 17)
33-
target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
34-
install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
29+
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
30+
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
31+
CUDA_STANDARD 17)
32+
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
33+
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
3534
endif()

examples/ffi/README.md

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,19 @@ Within the example project, there are several example calls:
1111
demonstrates the most basic use of the FFI. It also includes customization of
1212
behavior under automatic differentiation using `jax.custom_vjp`.
1313

14-
2. `counter`: This example demonstrates a common pattern for how an FFI call can
15-
use global cache to maintain state between calls. This pattern is useful when
16-
an FFI call requires an expensive initialization step which shouldn't be
17-
run on every execution, or if there is other shared state that could be
18-
reused between calls. In this simple example we just count the number of
19-
times the call was executed.
14+
2. `cpu_examples`: This submodule includes several smaller examples:
2015

21-
3. `attrs`: An example demonstrating the different ways that attributes can be
22-
passed to the FFI. For example, we can pass arrays, variadic attributes, and
23-
user-defined types. Full support of user-defined types isn't yet supported
24-
by XLA, so that example will be added in the future.
16+
* `counter`: This example demonstrates a common pattern for how an FFI call
17+
can use global cache to maintain state between calls. This pattern is
18+
useful when an FFI call requires an expensive initialization step which
19+
shouldn't be run on every execution, or if there is other shared state
20+
that could be reused between calls. In this simple example we just count
21+
the number of times the call was executed.
22+
* `attrs`: An example demonstrating the different ways that attributes can be
23+
passed to the FFI. For example, we can pass arrays, variadic attributes,
24+
and user-defined types. Full support of user-defined types isn't yet
25+
supported by XLA, so that example will be added in the future.
2526

26-
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
27-
CUDA. The specifics of the kernels are not very important, but the general
28-
structure, and packaging of the extension are useful for testing.
27+
3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI
28+
with CUDA. The specifics of the kernels are not very important, but the
29+
general structure, and packaging of the extension are useful for testing.

examples/ffi/src/jax_ffi_example/counter.cc

Lines changed: 0 additions & 53 deletions
This file was deleted.

examples/ffi/src/jax_ffi_example/counter.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

examples/ffi/src/jax_ffi_example/attrs.cc renamed to examples/ffi/src/jax_ffi_example/cpu_examples.cc

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,27 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include <cstdint>
17+
#include <mutex>
18+
#include <string_view>
19+
#include <unordered_map>
1720

1821
#include "nanobind/nanobind.h"
1922
#include "xla/ffi/api/ffi.h"
2023

2124
namespace nb = nanobind;
2225
namespace ffi = xla::ffi;
2326

27+
// ----------
28+
// Attributes
29+
// ----------
30+
//
31+
// An example demonstrating the different ways that attributes can be passed to
32+
// the FFI.
33+
//
34+
// For example, we can pass arrays, variadic attributes, and user-defined types.
35+
// Full support of user-defined types isn't yet supported by XLA, so that
36+
// example will be added in the future.
37+
2438
ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
2539
ffi::ResultBufferR0<ffi::S32> res) {
2640
int64_t total = 0;
@@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
5468
.Ret<ffi::BufferR0<ffi::S32>>()
5569
.Ret<ffi::BufferR0<ffi::S32>>());
5670

57-
NB_MODULE(_attrs, m) {
71+
// -------
72+
// Counter
73+
// -------
74+
//
75+
// An example demonstrating how an FFI call can maintain "state" between calls
76+
//
77+
// In this case, the ``Counter`` call simply accumulates the number of times it
78+
// was executed, but this pattern can also be used for more advanced use cases.
79+
// For example, this pattern is used in jaxlib for:
80+
//
81+
// 1. The GPU solver linear algebra kernels which require an expensive "handler"
82+
// initialization, and
83+
// 2. The ``triton_call`` function which caches the compiled triton modules
84+
// after their first use.
85+
86+
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
87+
static std::mutex mutex;
88+
static auto &cache = *new std::unordered_map<int64_t, int32_t>();
89+
{
90+
const std::lock_guard<std::mutex> lock(mutex);
91+
auto it = cache.find(index);
92+
if (it != cache.end()) {
93+
out->typed_data()[0] = ++it->second;
94+
} else {
95+
cache.insert({index, 0});
96+
out->typed_data()[0] = 0;
97+
}
98+
}
99+
return ffi::Error::Success();
100+
}
101+
102+
XLA_FFI_DEFINE_HANDLER_SYMBOL(
103+
Counter, CounterImpl,
104+
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
105+
106+
// Boilerplate for exposing handlers to Python
107+
NB_MODULE(_cpu_examples, m) {
58108
m.def("registrations", []() {
59109
nb::dict registrations;
60110
registrations["array_attr"] =
61111
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
62112
registrations["dictionary_attr"] =
63113
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
114+
115+
registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
116+
64117
return registrations;
65118
});
66119
}

examples/ffi/src/jax_ffi_example/attrs.py renamed to examples/ffi/src/jax_ffi_example/cpu_examples.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""An example demonstrating the different ways that attributes can be passed to
16-
the FFI.
17-
18-
For example, we can pass arrays, variadic attributes, and user-defined types.
19-
Full support of user-defined types isn't yet supported by XLA, so that example
20-
will be added in the future.
21-
"""
22-
2315
import numpy as np
2416

2517
import jax
2618
import jax.extend as jex
2719

28-
from jax_ffi_example import _attrs
20+
from jax_ffi_example import _cpu_examples
2921

30-
for name, target in _attrs.registrations().items():
22+
for name, target in _cpu_examples.registrations().items():
3123
jex.ffi.register_ffi_target(name, target)
3224

3325

@@ -43,3 +35,8 @@ def dictionary_attr(**kwargs):
4335
"dictionary_attr",
4436
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
4537
)(**kwargs)
38+
39+
40+
def counter(index):
41+
return jex.ffi.ffi_call(
42+
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

examples/ffi/src/jax_ffi_example/cuda_e2e.py renamed to examples/ffi/src/jax_ffi_example/cuda_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import jax.extend as jex
2828

2929
# Load the shared library with the FFI target definitions
30-
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so")
30+
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
3131
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
3232

3333
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),

examples/ffi/tests/counter_test.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments
 (0)