#include <cuda_runtime_api.h>

#include <pybind11/pybind11.h>

#include <cstdint>
#include <type_traits>

#include "check.h"

#include "mha_fwd.h"
#include "mha_bwd.h"

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;
18
18
19
19
// std::vector<at::Tensor>
20
20
// mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
295
295
296
296
namespace {
297
297
298
- template <typename T> pybind11::capsule EncapsulateFunction (T *fn) {
299
- return pybind11::capsule (reinterpret_cast <void *>(fn), " xla._CUSTOM_CALL_TARGET" );
300
- }
301
-
302
298
template <typename T>
303
- inline std::string PackDescriptorAsString (const T& descriptor) {
304
- return std::string (reinterpret_cast <const char *>(&descriptor), sizeof (T));
305
- }
306
-
307
- template <typename T> pybind11::bytes PackDescriptor (const T &descriptor) {
308
- return pybind11::bytes (PackDescriptorAsString (descriptor));
299
+ pybind11::capsule EncapsulateFfiCall (T *fn) {
300
+ static_assert (std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
301
+ " Encapsulated function must be an XLA FFI handler" );
302
+ return pybind11::capsule (reinterpret_cast <void *>(fn));
309
303
}
310
304
311
- pybind11::bytes make_mha_fwd_args ( float p_dropout,
312
- float softmax_scale,
313
- bool is_causal,
314
- int window_size_left,
315
- int window_size_right,
316
- bool return_softmax,
317
- int n, int l, int h, int d,
318
- int l_k, int h_k,
319
- ElementType dtype,
320
- uint64_t seed) {
321
- return PackDescriptor (mha_fwd_args{p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, return_softmax, n, l, h, d, l_k, h_k, dtype, seed});
322
- }
323
-
324
- pybind11::bytes make_mha_bwd_args ( float p_dropout,
325
- float softmax_scale,
326
- bool is_causal,
327
- int window_size_left,
328
- int window_size_right,
329
- bool deterministic,
330
- int n, int l, int h, int d,
331
- int l_k, int h_k,
332
- ElementType dtype,
333
- uint64_t seed) {
334
- return PackDescriptor (mha_bwd_args{p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, deterministic, n, l, h, d, l_k, h_k, dtype, seed});
335
- }
336
-
337
- pybind11::dict Registrations () {
305
+ XLA_FFI_DEFINE_HANDLER (
306
+ mha_fwd, mha_fwd_impl,
307
+ ffi::Ffi::Bind ()
308
+ .Ctx<ffi::PlatformStream<cudaStream_t>>()
309
+ .Ctx<ffi::ScratchAllocator>()
310
+ .Arg<ffi::AnyBuffer>()
311
+ .Arg<ffi::AnyBuffer>()
312
+ .Arg<ffi::AnyBuffer>()
313
+ .Ret<ffi::AnyBuffer>()
314
+ .Ret<ffi::Buffer<ffi::F32>>()
315
+ .Attr<double>(" softmax_scale" )
316
+ .Attr<bool>(" is_causal" )
317
+ .Attr<int64_t>(" window_size_left" )
318
+ .Attr<int64_t>(" window_size_right" )
319
+ );
320
+
321
+ XLA_FFI_DEFINE_HANDLER (
322
+ mha_bwd, mha_bwd_impl,
323
+ ffi::Ffi::Bind ()
324
+ .Ctx<ffi::PlatformStream<cudaStream_t>>()
325
+ .Ctx<ffi::ScratchAllocator>()
326
+ .Arg<ffi::AnyBuffer>() // dout
327
+ .Arg<ffi::AnyBuffer>() // q
328
+ .Arg<ffi::AnyBuffer>() // k
329
+ .Arg<ffi::AnyBuffer>() // v
330
+ .Arg<ffi::AnyBuffer>() // o
331
+ .Arg<ffi::Buffer<ffi::F32>>() // lse
332
+ .Ret<ffi::AnyBuffer>() // dq
333
+ .Ret<ffi::AnyBuffer>() // dk
334
+ .Ret<ffi::AnyBuffer>() // dv
335
+ .Attr<double>(" softmax_scale" )
336
+ .Attr<bool>(" is_causal" )
337
+ .Attr<int64_t>(" window_size_left" )
338
+ .Attr<int64_t>(" window_size_right" )
339
+ );
340
+
341
+
342
+ pybind11::dict FFIRegistrations () {
338
343
pybind11::dict dict;
339
- dict[" flash_mha_fwd" ] = EncapsulateFunction (mha_fwd);
340
- dict[" flash_mha_bwd" ] = EncapsulateFunction (mha_bwd);
344
+ dict[" flash_mha_fwd" ] = EncapsulateFfiCall (mha_fwd);
345
+ dict[" flash_mha_bwd" ] = EncapsulateFfiCall (mha_bwd);
341
346
return dict;
342
347
}
343
348
344
349
345
350
PYBIND11_MODULE (flash_api, m) {
346
351
m.doc () = " FlashAttention" ;
347
- m.def (" get_registrations" , &Registrations);
348
- m.def (" make_flash_mha_fwd_args" , &make_mha_fwd_args);
349
- m.def (" make_flash_mha_bwd_args" , &make_mha_bwd_args);
350
- pybind11::enum_<ElementType>(m, " ElementType" )
351
- .value (" BF16" , BF16)
352
- .value (" FP16" , FP16)
353
- .export_values ();
352
+ m.def (" get_ffi_registrations" , &FFIRegistrations);
354
353
355
354
// m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
356
355
// m.def("bwd", &mha_bwd, "Backward pass");
357
356
// m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
358
357
// m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
359
358
}
360
359
361
- }
360
+ } // namespace
0 commit comments