
Commit 1ca96c7

Add varlen attention support (no vmap or sharding support yet).
1 parent d7d0ff2

18 files changed: +1037 -499 lines

.clang-format

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+---
+Language: Cpp
+ColumnLimit: 100
+BasedOnStyle: Google

CMakeLists.txt

Lines changed: 4 additions & 2 deletions

@@ -27,7 +27,7 @@ string(REGEX REPLACE "--generate-code=arch=compute_[0-9]+,code=\\[?compute_[0-9]
 string(REGEX REPLACE "-gencode arch=compute_[0-9]+,code=sm_[0-9]+" ""
        CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

-message(WARNING "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
+message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

 # Set up ccache
 find_program(CCACHE_PROGRAM ccache)

@@ -49,9 +49,11 @@ find_package(CUDAToolkit REQUIRED)


 # CUDA flags
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
 set(CUDA_FLAGS
   -O3
-  -std=c++20
   --use_fast_math
   --expt-relaxed-constexpr
   --expt-extended-lambda
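
With this hunk, the C++ language level comes from CMake's CMAKE_CXX_STANDARD machinery rather than a hand-written -std=c++20 entry in CUDA_FLAGS. As an illustrative sanity check (a sketch, not part of this commit), any C++ translation unit built under the new settings should pass:

    // Fails compilation if the compiler is not actually in C++20 mode;
    // 202002L is the standard __cplusplus value for C++20.
    static_assert(__cplusplus >= 202002L, "C++20 is required");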

csrc/flash_attn/check.h

Lines changed: 7 additions & 2 deletions

@@ -14,7 +14,6 @@ class CheckHelper {
   explicit CheckHelper(std::string expr) : expr_(expr) {}

   template <typename T> inline CheckHelper &operator<<(const T &value) {
-    fprintf(stderr, "debug: adding value %s\n", value);
     stream_ << value;
     return *this;
   }

@@ -29,7 +28,6 @@ class CheckHelper {
     full_message << "Check failed: " << expr_;
     std::string additional = stream_.str();
     if (!additional.empty()) {
-      fprintf(stderr, "debug: %s\n", additional.c_str());
      full_message << "; " << additional;
    }
    return ffi::Error(errc_, full_message.str());

@@ -63,3 +61,10 @@ class CheckHelper {
 #define FFI_RET_CHECK(expr) \
   if (auto _error = (expr); !_error.success()) \
     return _error
+
+#define FFI_CHECK_ALLOC(dest, expr)          \
+  void* dest = nullptr;                      \
+  if (auto _opt = (expr); _opt.has_value())  \
+    dest = _opt.value();                     \
+  else                                       \
+    return CheckHelper(std::string(#expr))
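
The new FFI_CHECK_ALLOC complements the existing FFI_RET_CHECK: it declares void* dest, evaluates an expression that must yield an optional-like value (the macro calls has_value()/value()), and on failure returns a CheckHelper carrying the stringified expression, which converts to an ffi::Error as in the class above. A minimal usage sketch, assuming the allocator returns std::optional<void*> as the macro implies; some_impl, workspace_bytes, and launch_kernel are hypothetical names:

    ffi::Error some_impl(cudaStream_t stream, ffi::ScratchAllocator scratch) {
      size_t workspace_bytes = 1 << 20;  // hypothetical workspace size
      // Declares `void* workspace`; on allocation failure, returns an
      // ffi::Error built from the text of the failed expression.
      FFI_CHECK_ALLOC(workspace, scratch.Allocate(workspace_bytes));
      // Propagates a failing ffi::Error from the callee unchanged.
      FFI_RET_CHECK(launch_kernel(stream, workspace));
      return ffi::Error::Success();
    }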

csrc/flash_attn/flash_api.cpp

Lines changed: 50 additions & 0 deletions

@@ -338,11 +338,61 @@ XLA_FFI_DEFINE_HANDLER(
     .Attr<int64_t>("window_size_right")
 );

+XLA_FFI_DEFINE_HANDLER(
+    mha_varlen_fwd, mha_varlen_fwd_impl,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Ctx<ffi::ScratchAllocator>()
+        .Arg<ffi::AnyBuffer>()         // q
+        .Arg<ffi::AnyBuffer>()         // k
+        .Arg<ffi::AnyBuffer>()         // v
+        .Arg<ffi::Buffer<ffi::S32>>()  // cu_seqlens_q
+        .Arg<ffi::Buffer<ffi::S32>>()  // cu_seqlens_k
+        .Arg<ffi::Buffer<ffi::S32>>()  // seqused_k
+        .Ret<ffi::AnyBuffer>()         // o
+        .Ret<ffi::Buffer<ffi::F32>>()  // lse
+        .Attr<int>("max_seqlen_q")
+        .Attr<int>("max_seqlen_k")
+        .Attr<bool>("has_seqused_k")
+        .Attr<double>("softmax_scale")
+        .Attr<bool>("zero_tensors")
+        .Attr<bool>("is_causal")
+        .Attr<int64_t>("window_size_left")
+        .Attr<int64_t>("window_size_right")
+);
+
+XLA_FFI_DEFINE_HANDLER(
+    mha_varlen_bwd, mha_varlen_bwd_impl,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Ctx<ffi::ScratchAllocator>()
+        .Arg<ffi::AnyBuffer>()         // dout
+        .Arg<ffi::AnyBuffer>()         // q
+        .Arg<ffi::AnyBuffer>()         // k
+        .Arg<ffi::AnyBuffer>()         // v
+        .Arg<ffi::AnyBuffer>()         // o
+        .Arg<ffi::Buffer<ffi::F32>>()  // lse
+        .Arg<ffi::Buffer<ffi::S32>>()  // cu_seqlens_q
+        .Arg<ffi::Buffer<ffi::S32>>()  // cu_seqlens_k
+        .Ret<ffi::AnyBuffer>()         // dq
+        .Ret<ffi::AnyBuffer>()         // dk
+        .Ret<ffi::AnyBuffer>()         // dv
+        .Attr<int64_t>("max_seqlen_q")
+        .Attr<int64_t>("max_seqlen_k")
+        .Attr<float>("softmax_scale")
+        .Attr<bool>("zero_tensors")
+        .Attr<bool>("is_causal")
+        .Attr<int64_t>("window_size_left")
+        .Attr<int64_t>("window_size_right")
+        .Attr<bool>("deterministic")
+);

 pybind11::dict FFIRegistrations() {
   pybind11::dict dict;
   dict["flash_mha_fwd"] = EncapsulateFfiCall(mha_fwd);
   dict["flash_mha_bwd"] = EncapsulateFfiCall(mha_bwd);
+  dict["flash_mha_varlen_fwd"] = EncapsulateFfiCall(mha_varlen_fwd);
+  dict["flash_mha_varlen_bwd"] = EncapsulateFfiCall(mha_varlen_bwd);
   return dict;
 }
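
In XLA's FFI, each .Ctx<>, .Arg<>, .Ret<>, and .Attr<> in a binding maps positionally to one handler parameter, with Ret<T> arriving as ffi::Result<T>. So the mha_varlen_fwd binding above implies a signature along these lines (a sketch for orientation only; the real mha_varlen_fwd_impl is defined elsewhere in this commit):

    ffi::Error mha_varlen_fwd_impl(
        cudaStream_t stream,               // Ctx<PlatformStream<cudaStream_t>>
        ffi::ScratchAllocator scratch,     // Ctx<ScratchAllocator>
        ffi::AnyBuffer q, ffi::AnyBuffer k, ffi::AnyBuffer v,
        ffi::Buffer<ffi::S32> cu_seqlens_q,
        ffi::Buffer<ffi::S32> cu_seqlens_k,
        ffi::Buffer<ffi::S32> seqused_k,
        ffi::Result<ffi::AnyBuffer> o,     // each Ret<T> becomes Result<T>
        ffi::Result<ffi::Buffer<ffi::F32>> lse,
        int max_seqlen_q, int max_seqlen_k, bool has_seqused_k,
        double softmax_scale, bool zero_tensors, bool is_causal,
        int64_t window_size_left, int64_t window_size_right);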
