Hi! I am currently trying to extend the JAX Python package with Rust bindings. My goal is to adapt the "Extending JAX with C++" tutorial to Rust, if possible using PyO3 and Maturin. Unfortunately, generating XLA-compatible code relies heavily on C++ features such as macros and templates, for which it seems impossible to generate bindings automatically (through …)
I could probably stop at step (3) and use, e.g., nanobind to generate the Python module from C++, but I would prefer to stick with PyO3 and Maturin, as they are much more convenient in my case :-)

The Rust code looks as follows:

```rust
use std::ffi::CString;

use pyo3::prelude::*;
use pyo3::types::PyCapsule;

fn rms_norm(eps: f32, x: &[f32], y: &mut [f32]) {
    /* actual implementation */
}

#[cxx::bridge]
mod ffi {
    extern "Rust" {
        // Expose our Rust function to C++
        fn rms_norm(eps: f32, x: &[f32], y: &mut [f32]);
    }

    unsafe extern "C++" {
        include!("rms-norm/include/ffi.h");

        type XLA_FFI_Error;
        type XLA_FFI_CallFrame;

        // The XLA-compatible C++ wrapper around our `rms_norm` Rust function
        unsafe fn RmsNorm(call_frame: *mut XLA_FFI_CallFrame) -> *mut XLA_FFI_Error;
    }
}

#[pymodule]
fn _rms_norm(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let name = CString::new("rms_norm").unwrap();
    let f: unsafe fn(*mut ffi::XLA_FFI_CallFrame) -> *mut ffi::XLA_FFI_Error = ffi::RmsNorm;
    m.add("rms_norm", PyCapsule::new(m.py(), f, Some(name))?)?;
    Ok(())
}
```

When I register the custom FFI call from Python with:

```python
import jax.extend as jex
import rms_norm._rms_norm as rms_norm_lib

jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm, platform="cpu")
```

it later crashes with a segmentation fault, so my question is: am I passing the pointer to the C++ function correctly? When reading the C++ example, I noticed that they register the callable as a …

Thanks for your help! You can find the full MWE code here: https://github.com/jeertmans/extending-jax/tree/5297b806c8f434030612875e270e3f598ec0e38d
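The question comes down to which address JAX finds inside the capsule. One possible explanation, which I have not verified: the C++ tutorial, as far as I can tell, builds its capsule with `nb::capsule(reinterpret_cast<void *>(fn))`, so the capsule's pointer *is* the handler's address, whereas PyO3's `PyCapsule::new` moves the value it is given into a heap allocation and stores a pointer to that allocation. If JAX then calls that pointer as a function, execution would jump into heap memory, which would fit a segfault. The dependency-free sketch below models both layouts with plain Rust pointers. `Handler`, `rms_norm_handler`, `pointer_as_function`, `pointer_to_boxed_value` and `as_handler` are hypothetical names, and `c_void` stands in for the XLA FFI types.

```rust
use std::ffi::c_void;

// Stand-ins for the opaque XLA FFI types; in the real crate these would be the
// cxx-generated `ffi::XLA_FFI_CallFrame` and `ffi::XLA_FFI_Error`.
type CallFrame = c_void;
type FfiError = c_void;

// A C-ABI handler type, i.e. `XLA_FFI_Error* (*)(XLA_FFI_CallFrame*)` in C++.
type Handler = unsafe extern "C" fn(*mut CallFrame) -> *mut FfiError;

// Hypothetical handler standing in for the C++ `RmsNorm` wrapper.
unsafe extern "C" fn rms_norm_handler(_frame: *mut CallFrame) -> *mut FfiError {
    std::ptr::null_mut() // a null error pointer means success
}

// Layout A: the capsule's pointer *is* the function address.
fn pointer_as_function(f: Handler) -> *mut c_void {
    f as *mut c_void
}

// Layout B: the function pointer is moved into a heap allocation and the
// capsule's pointer refers to that allocation (assumed to mirror what
// `PyCapsule::new` does with the value it receives).
fn pointer_to_boxed_value(f: Handler) -> *mut c_void {
    Box::into_raw(Box::new(f)) as *mut c_void
}

// What a consumer expecting layout A does: reinterpret the pointer as a handler.
unsafe fn as_handler(p: *mut c_void) -> Handler {
    unsafe { std::mem::transmute::<*mut c_void, Handler>(p) }
}

fn main() {
    let f: Handler = rms_norm_handler;

    // Layout A round-trips: the consumer gets back the very same function.
    let recovered = unsafe { as_handler(pointer_as_function(f)) };
    assert_eq!(recovered as usize, f as usize);
    assert!(unsafe { recovered(std::ptr::null_mut()) }.is_null());

    // Layout B: the pointer is the heap slot's address, not the function's.
    // Calling `as_handler(boxed)` would jump into heap memory (undefined
    // behaviour, typically a segfault), so we only read the stored function
    // pointer back through the correct type.
    let boxed = pointer_to_boxed_value(f);
    assert_ne!(boxed as usize, f as usize);
    let stored = unsafe { *(boxed as *mut Handler) };
    assert_eq!(stored as usize, f as usize);

    // Free the allocation made for layout B.
    unsafe { drop(Box::from_raw(boxed as *mut Handler)) };
    println!("layout A: pointer == handler; layout B: pointer != handler");
}
```

If that reading is right, a PyO3 version of layout A could bypass `PyCapsule::new`, call the raw `pyo3::ffi::PyCapsule_New` with the handler's address (and, like the nanobind call, a null name), and wrap the result with `Bound::from_owned_ptr_or_err`. Separately, the `f` binding in the module code is typed `unsafe fn(...)`, which uses the Rust ABI, while JAX will call the pointer with the C ABI; that is why the sketch uses `unsafe extern "C" fn`. Neither point has been tested against JAX.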
Hi @davidhewitt (sorry to ping you specifically; I picked the first maintainer on the list), this is a gentle bump, as I haven't received a reply yet :-)