
Commit d7d0ff2

Refactor the frontend to use the new JAX FFI API, which greatly improves error handling and allows using XLA's allocator for scratchpad arrays.
1 parent 6d55af7 commit d7d0ff2

File tree: 13 files changed (+4675 -562 lines)

CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,8 @@ set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
 # == Find dependencies ==
 find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
 
+message(STATUS "Python executable: ${Python_EXECUTABLE}")
+
 execute_process(
   COMMAND ${Python_EXECUTABLE} -m pybind11 --cmakedir
   OUTPUT_VARIABLE pybind11_DIR
@@ -88,6 +90,8 @@ target_include_directories(flash_api PRIVATE
   ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn
   ${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/src
   ${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass/include
+  ${CMAKE_CURRENT_SOURCE_DIR}/csrc
+
 )
 
 target_link_libraries(flash_api PRIVATE

csrc/flash_attn/check.h

Lines changed: 63 additions & 16 deletions

@@ -1,18 +1,65 @@
 #pragma once
 
-#include <stdio.h>
-
-inline void check_implementation(bool expr, std::string check_message) {
-  if (!expr) {
-    fprintf(stderr, "%s\n", check_message.c_str());
-    abort();
-  }
-}
-
-#define CHECK(EXPR, MESSAGE) \
-  do { \
-    const bool __err = EXPR; \
-    check_implementation( \
-      __err, \
-      MESSAGE); \
-  } while (0)
+#include <cstdio>
+#include <cuda_runtime_api.h>
+#include <driver_types.h>
+#include <string>
+
+#include "xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
+
+class CheckHelper {
+public:
+  explicit CheckHelper(std::string expr) : expr_(expr) {}
+
+  template <typename T> inline CheckHelper &operator<<(const T &value) {
+    fprintf(stderr, "debug: adding value %s\n", value);
+    stream_ << value;
+    return *this;
+  }
+
+  inline CheckHelper &operator<<(ffi::ErrorCode errc) {
+    errc_ = errc;
+    return *this;
+  }
+
+  inline operator ffi::Error() {
+    std::ostringstream full_message;
+    full_message << "Check failed: " << expr_;
+    std::string additional = stream_.str();
+    if (!additional.empty()) {
+      fprintf(stderr, "debug: %s\n", additional.c_str());
+      full_message << "; " << additional;
+    }
+    return ffi::Error(errc_, full_message.str());
+  }
+
+private:
+  ffi::ErrorCode errc_ = ffi::ErrorCode::kUnknown;
+  std::string expr_;
+  std::ostringstream stream_;
+};
+
+#define FFI_CHECK(expr) \
+  static_assert(!std::is_same_v<decltype(expr), cudaError_t>, \
+                "Use FFI_CUDA_CHECK for CUDA error codes, not FFI_CHECK."); \
+  if (!(expr)) \
+  return CheckHelper(#expr)
+
+#define FFI_CUDA_CHECK(expr) \
+  static_assert(std::is_same_v<decltype(expr), cudaError_t>, \
+                "Expect cudaError_t for FFI_CUDA_CHECK."); \
+  if (cudaError_t _cuda_check = (expr); _cuda_check != cudaSuccess) \
+  return CheckHelper(std::string(#expr)) \
+      << " CUDA Error: " << cudaGetErrorString(_cuda_check)
+
+#define FFI_CHECK_OPTIONAL(dest, expr) \
+  if (auto _opt = (expr); _opt.has_value()) \
+    dest = _opt.value(); \
+  else \
+    return CheckHelper(std::string(#expr))
+
+#define FFI_RET_CHECK(expr) \
+  if (auto _error = (expr); !_error.success()) \
+  return _error
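These macros replace the old CHECK/abort() path: a failed check now returns an ffi::Error to XLA instead of killing the process, which is the error-handling improvement mentioned in the commit message. Below is a minimal usage sketch, not code from this commit: the handler name and the fp16 requirement are invented for illustration, and ffi::ErrorCode::kInvalidArgument / AnyBuffer::element_type() are assumed from the XLA FFI API headers.

#include <cuda_runtime_api.h>

#include "check.h"

// Hypothetical handler body showing how the macros compose. On failure each
// macro returns an ffi::Error built by CheckHelper; on success control falls
// through to the next statement.
ffi::Error example_handler(cudaStream_t stream, ffi::AnyBuffer q) {
  // Predicate check with an explicit error code and extra message text.
  FFI_CHECK(q.element_type() == ffi::DataType::F16)
      << ffi::ErrorCode::kInvalidArgument << "q must be fp16";

  // CUDA-call check; cudaGetErrorString() output is appended to the message.
  FFI_CUDA_CHECK(cudaStreamSynchronize(stream));

  return ffi::Error();  // success
}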

csrc/flash_attn/flash_api.cpp

Lines changed: 50 additions & 51 deletions

@@ -7,14 +7,14 @@
 #include <cuda_runtime_api.h>
 #include <pybind11/pybind11.h>
 
-#include "flash.h"
-#include "exception.h"
-#include "static_switch.h"
 #include "check.h"
 
-#include "flash_common.h"
 #include "mha_fwd.h"
 #include "mha_bwd.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
 
 // std::vector<at::Tensor>
 // mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
@@ -295,67 +295,66 @@
 
 namespace {
 
-template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
-  return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
-}
-
 template <typename T>
-inline std::string PackDescriptorAsString(const T& descriptor) {
-  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
-}
-
-template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
-  return pybind11::bytes(PackDescriptorAsString(descriptor));
+pybind11::capsule EncapsulateFfiCall(T *fn) {
+  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
+                "Encapsulated function must be an XLA FFI handler");
+  return pybind11::capsule(reinterpret_cast<void *>(fn));
 }
 
-pybind11::bytes make_mha_fwd_args( float p_dropout,
-                                   float softmax_scale,
-                                   bool is_causal,
-                                   int window_size_left,
-                                   int window_size_right,
-                                   bool return_softmax,
-                                   int n, int l, int h, int d,
-                                   int l_k, int h_k,
-                                   ElementType dtype,
-                                   uint64_t seed) {
-  return PackDescriptor(mha_fwd_args{p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, return_softmax, n, l, h, d, l_k, h_k, dtype, seed});
-}
-
-pybind11::bytes make_mha_bwd_args( float p_dropout,
-                                   float softmax_scale,
-                                   bool is_causal,
-                                   int window_size_left,
-                                   int window_size_right,
-                                   bool deterministic,
-                                   int n, int l, int h, int d,
-                                   int l_k, int h_k,
-                                   ElementType dtype,
-                                   uint64_t seed) {
-  return PackDescriptor(mha_bwd_args{p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, deterministic, n, l, h, d, l_k, h_k, dtype, seed});
-}
-
-pybind11::dict Registrations() {
+XLA_FFI_DEFINE_HANDLER(
+    mha_fwd, mha_fwd_impl,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Ctx<ffi::ScratchAllocator>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ret<ffi::Buffer<ffi::F32>>()
+        .Attr<double>("softmax_scale")
+        .Attr<bool>("is_causal")
+        .Attr<int64_t>("window_size_left")
+        .Attr<int64_t>("window_size_right")
+);
+
+XLA_FFI_DEFINE_HANDLER(
+    mha_bwd, mha_bwd_impl,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Ctx<ffi::ScratchAllocator>()
+        .Arg<ffi::AnyBuffer>()          // dout
+        .Arg<ffi::AnyBuffer>()          // q
+        .Arg<ffi::AnyBuffer>()          // k
+        .Arg<ffi::AnyBuffer>()          // v
+        .Arg<ffi::AnyBuffer>()          // o
+        .Arg<ffi::Buffer<ffi::F32>>()   // lse
+        .Ret<ffi::AnyBuffer>()          // dq
+        .Ret<ffi::AnyBuffer>()          // dk
+        .Ret<ffi::AnyBuffer>()          // dv
+        .Attr<double>("softmax_scale")
+        .Attr<bool>("is_causal")
+        .Attr<int64_t>("window_size_left")
+        .Attr<int64_t>("window_size_right")
+);
+
+
+pybind11::dict FFIRegistrations() {
   pybind11::dict dict;
-  dict["flash_mha_fwd"] = EncapsulateFunction(mha_fwd);
-  dict["flash_mha_bwd"] = EncapsulateFunction(mha_bwd);
+  dict["flash_mha_fwd"] = EncapsulateFfiCall(mha_fwd);
+  dict["flash_mha_bwd"] = EncapsulateFfiCall(mha_bwd);
   return dict;
 }
 
 
 PYBIND11_MODULE(flash_api, m) {
   m.doc() = "FlashAttention";
-  m.def("get_registrations", &Registrations);
-  m.def("make_flash_mha_fwd_args", &make_mha_fwd_args);
-  m.def("make_flash_mha_bwd_args", &make_mha_bwd_args);
-  pybind11::enum_<ElementType>(m, "ElementType")
-      .value("BF16", BF16)
-      .value("FP16", FP16)
-      .export_values();
+  m.def("get_ffi_registrations", &FFIRegistrations);
 
   // m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
   // m.def("bwd", &mha_bwd, "Backward pass");
   // m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
   // m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
 
-}
+} // namespace
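The handler implementations referenced above (mha_fwd_impl, mha_bwd_impl) live in mha_fwd.h and mha_bwd.h and are not part of this diff. As a rough orientation only, the forward Bind() chain implies a handler signature along the following lines; the parameter names are illustrative, and the exact spelling of the FFI parameter types (ffi::Result, by-value ffi::ScratchAllocator) is an assumption based on the XLA FFI API, not taken from this commit.

#include <cuda_runtime_api.h>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Sketch only: each Ctx<>/Arg<>/Ret<>/Attr<> entry in Bind() maps, in order,
// to one parameter of the handler implementation.
ffi::Error mha_fwd_impl(cudaStream_t stream,                     // Ctx<PlatformStream<cudaStream_t>>
                        ffi::ScratchAllocator scratch,           // Ctx<ScratchAllocator>
                        ffi::AnyBuffer q,                        // Arg
                        ffi::AnyBuffer k,                        // Arg
                        ffi::AnyBuffer v,                        // Arg
                        ffi::Result<ffi::AnyBuffer> o,           // Ret: attention output
                        ffi::Result<ffi::Buffer<ffi::F32>> lse,  // Ret: softmax log-sum-exp
                        double softmax_scale,                    // Attr "softmax_scale"
                        bool is_causal,                          // Attr "is_causal"
                        int64_t window_size_left,                // Attr "window_size_left"
                        int64_t window_size_right);              // Attr "window_size_right"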

csrc/flash_attn/flash_common.cpp

Lines changed: 19 additions & 15 deletions

@@ -5,13 +5,14 @@
 #include <pybind11/pybind11.h>
 
 #include "flash.h"
-#include "exception.h"
-#include "static_switch.h"
 #include "check.h"
 #include "flash_common.h"
+#include "xla/ffi/api/ffi.h"
 
-void set_params_fprop(Flash_fwd_params &params,
-                      ElementType element_type,
+namespace ffi = xla::ffi;
+
+ffi::Error set_params_fprop(Flash_fwd_params &params,
+                            ffi::DataType element_type,
                       // sizes
                       const size_t b,
                       const size_t seqlen_q,
@@ -41,7 +42,7 @@ void set_params_fprop(Flash_fwd_params &params,
     // Reset the parameters
     memset(&params, 0, sizeof(params));
 
-    params.is_bf16 = element_type == BF16;
+    params.is_bf16 = element_type == ffi::DataType::BF16;
 
     // Set the pointers and strides.
     params.q_ptr = q_ptr;
@@ -110,7 +111,7 @@ void set_params_fprop(Flash_fwd_params &params,
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    CHECK(p_dropout < 1.f, "dropout must be <1");
+    FFI_CHECK(p_dropout < 1.f) << "dropout must be <1";
 
     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@@ -122,6 +123,8 @@ void set_params_fprop(Flash_fwd_params &params,
     params.window_size_right = window_size_right;
 
     params.is_seqlens_k_cumulative = true;
+
+    return ffi::Error(); // Success
 }
 
 // Find the number of splits that maximizes the occupancy. For example, if we have
@@ -166,10 +169,10 @@ int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks
     return 1;
 }
 
-void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+ffi::Error set_params_splitkv(ffi::ScratchAllocator* scratch, Flash_fwd_params& params, const int batch_size,
                         const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
                         const int head_size_rounded, const float p_dropout,
-                        const int num_splits, int multiProcessorCount, ElementType dtype) {
+                        const int num_splits, int multiProcessorCount, ffi::DataType dtype) {
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
@@ -184,13 +187,14 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, multiProcessorCount, num_n_blocks, 128);
        }
        if (params.num_splits > 1) {
-            // at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            // at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-            C10_CUDA_CHECK(cudaMalloc((void**)&params.softmax_lseaccum_ptr, params.num_splits * batch_size * num_heads * max_seqlen_q * 4)); // float32
-            C10_CUDA_CHECK(cudaMalloc((void**)&params.oaccum_ptr, params.num_splits * batch_size * num_heads * max_seqlen_q * head_size_rounded * 4));
-            // params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-            // params.oaccum_ptr = out_accum.data_ptr();
+            FFI_CHECK_OPTIONAL(*(void**)&params.softmax_lseaccum_ptr, scratch->Allocate(
+                params.num_splits * batch_size * num_heads * max_seqlen_q * 4, 4))
+                << "Failed to allocate memory for softmax_lseaccum";
+            FFI_CHECK_OPTIONAL(*(void**)&params.oaccum_ptr, scratch->Allocate(
+                params.num_splits * batch_size * num_heads * max_seqlen_q * head_size_rounded * 4, 4))
+                << "Failed to allocate memory for oaccum";
        }
-        CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+        FFI_CHECK(params.num_splits <= 128) << "num_splits > 128 not supported - " << params.num_splits;
    }
+    return ffi::Error();
 }
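With set_params_fprop and set_params_splitkv now returning ffi::Error, callers can bubble failures up with FFI_RET_CHECK instead of aborting, and the split-KV accumulators come from XLA's scratch allocator rather than raw cudaMalloc. Below is a hypothetical call-site sketch: the prepare_splitkv wrapper and its argument names are invented; only set_params_splitkv and the macros come from this repo.

#include "check.h"
#include "flash_common.h"

// Hypothetical wrapper illustrating error propagation from inside an FFI
// handler: if set_params_splitkv fails (e.g. a scratch allocation returns
// std::nullopt), FFI_RET_CHECK returns that ffi::Error to the caller.
ffi::Error prepare_splitkv(ffi::ScratchAllocator &scratch, Flash_fwd_params &params,
                           int batch_size, int num_heads, int head_size,
                           int max_seqlen_k, int max_seqlen_q, int head_size_rounded,
                           float p_dropout, int num_splits,
                           int multi_processor_count, ffi::DataType dtype) {
  FFI_RET_CHECK(set_params_splitkv(&scratch, params, batch_size, num_heads, head_size,
                                   max_seqlen_k, max_seqlen_q, head_size_rounded,
                                   p_dropout, num_splits, multi_processor_count, dtype));
  return ffi::Error();  // success
}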

csrc/flash_attn/flash_common.h

Lines changed: 6 additions & 20 deletions

@@ -6,14 +6,13 @@
 #include <pybind11/pybind11.h>
 
 #include "flash.h"
-#include "exception.h"
-#include "static_switch.h"
 #include "check.h"
+#include "xla/ffi/api/ffi.h"
 
-enum ElementType { BF16, FP16, FP32 };
+namespace ffi = xla::ffi;
 
-void set_params_fprop(Flash_fwd_params &params,
-                      ElementType element_type,
+ffi::Error set_params_fprop(Flash_fwd_params &params,
+                            ffi::DataType element_type,
                       // sizes
                       const size_t b,
                       const size_t seqlen_q,
@@ -40,20 +39,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       int window_size_right,
                       bool seqlenq_ngroups_swapped=false);
 
-void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+ffi::Error set_params_splitkv(ffi::ScratchAllocator* scratch, Flash_fwd_params& params, const int batch_size,
                         const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
                         const int head_size_rounded, const float p_dropout,
-                        const int num_splits, int multiProcessorCount, ElementType dtype);
-
-template <typename T>
-inline std::string Pack(const T& args) {
-  return std::string(reinterpret_cast<const char*>(&args), sizeof(T));
-}
-
-template <typename T>
-inline T Unpack(const void* opaque, size_t opaque_len) {
-  T out;
-  CHECK(sizeof(out)==opaque_len, "opaque len");
-  memcpy(&out, opaque, opaque_len);
-  return out;
-}
+                        const int num_splits, int multiProcessorCount, ffi::DataType dtype);
