Commit e72fac9

Synchronize stream in mha to prevent use-after-free of scratchpad memory.
1 parent: 43dd029
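
Why the synchronize is needed, read schematically from the commit message: the flash-attention kernels are launched asynchronously on the handler's CUDA stream, while workspace memory obtained from the FFI ScratchAllocator can be reclaimed as soon as the handler returns. Below is a minimal, hypothetical sketch of the failure mode and the fix; the kernel, sizes, and handler name are invented for illustration, and only the FFI_* macros and the launch-then-synchronize pattern mirror this commit.

// Illustrative sketch only -- not the actual flash-attention handler.
// Assumption: scratch memory from ffi::ScratchAllocator may be reused or
// freed once the handler returns, so any kernel still running on `stream`
// would then touch freed memory (the use-after-free this commit prevents).

__global__ void fill_kernel(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
}

ffi::Error handler_sketch(cudaStream_t stream, ffi::ScratchAllocator scratch) {
  constexpr int kElems = 1 << 18;
  FFI_CHECK_ALLOC(workspace, scratch.Allocate(kElems * sizeof(float)));

  // Kernel launches are asynchronous: this call only enqueues work.
  fill_kernel<<<(kElems + 255) / 256, 256, 0, stream>>>(
      static_cast<float*>(workspace), kElems);

  // Without this, the handler could return (and the scratchpad be reclaimed)
  // while the kernel is still writing `workspace`.  Synchronizing here is
  // what the mha_fwd/mha_bwd changes below add after launch()/run_mha_fwd().
  FFI_CUDA_CHECK(cudaStreamSynchronize(stream));
  return ffi::Error::Success();
}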

File tree

6 files changed (+27, -25 lines)

.clang-format

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 ---
 Language: Cpp
 ColumnLimit: 100
-BasedOnStyle: Google
+# BasedOnStyle: Google

csrc/flash_attn/check.h

Lines changed: 20 additions & 21 deletions
@@ -39,32 +39,31 @@ class CheckHelper {
   std::ostringstream stream_;
 };
 
-#define FFI_CHECK(expr)                                                       \
-  static_assert(!std::is_same_v<decltype(expr), cudaError_t>,                 \
-                "Use FFI_CUDA_CHECK for CUDA error codes, not FFI_CHECK.");   \
-  if (!(expr))                                                                \
+#define FFI_CHECK(expr)                                                     \
+  static_assert(!std::is_same_v<decltype(expr), cudaError_t>,               \
+                "Use FFI_CUDA_CHECK for CUDA error codes, not FFI_CHECK."); \
+  if (!(expr))                                                              \
     return CheckHelper(#expr)
 
-#define FFI_CUDA_CHECK(expr)                                                  \
-  static_assert(std::is_same_v<decltype(expr), cudaError_t>,                  \
-                "Expect cudaError_t for FFI_CUDA_CHECK.");                    \
-  if (cudaError_t _cuda_check = (expr); _cuda_check != cudaSuccess)           \
-    return CheckHelper(std::string(#expr))                                    \
-        << " CUDA Error: " << cudaGetErrorString(_cuda_check)
+#define FFI_CUDA_CHECK(expr)                                        \
+  static_assert(std::is_same_v<decltype(expr), cudaError_t>,        \
+                "Expect cudaError_t for FFI_CUDA_CHECK.");          \
+  if (cudaError_t _cuda_check = (expr); _cuda_check != cudaSuccess) \
+    return CheckHelper(std::string(#expr)) << " CUDA Error: " << cudaGetErrorString(_cuda_check)
 
-#define FFI_CHECK_OPTIONAL(dest, expr)                                        \
-  if (auto _opt = (expr); _opt.has_value())                                   \
-    dest = _opt.value();                                                      \
-  else                                                                        \
+#define FFI_CHECK_OPTIONAL(dest, expr)      \
+  if (auto _opt = (expr); _opt.has_value()) \
+    dest = _opt.value();                    \
+  else                                      \
     return CheckHelper(std::string(#expr))
 
-#define FFI_RET_CHECK(expr)                                                   \
-  if (auto _error = (expr); !_error.success())                                \
+#define FFI_RET_CHECK(expr)                    \
+  if (auto _error = (expr); !_error.success()) \
     return _error
 
-#define FFI_CHECK_ALLOC(dest, expr)                                           \
-  void* dest = nullptr;                                                       \
-  if (auto _opt = (expr); _opt.has_value())                                   \
-    dest = _opt.value();                                                      \
-  else                                                                        \
+#define FFI_CHECK_ALLOC(dest, expr)         \
+  void *dest = nullptr;                     \
+  if (auto _opt = (expr); _opt.has_value()) \
+    dest = _opt.value();                    \
+  else                                      \
     return CheckHelper(std::string(#expr))
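
For reference, a hypothetical sketch of how these macros compose inside an FFI handler. The handler name, helper function, and buffer size are invented; it assumes, as the surrounding code suggests, that ffi::ScratchAllocator::Allocate returns a std::optional and that the error type exposes success()/Success().

// Hypothetical usage sketch, not code from this repository; only the macro
// semantics come from check.h above.

static ffi::Error zero_workspace(void *workspace, size_t bytes, cudaStream_t stream) {
  FFI_CUDA_CHECK(cudaMemsetAsync(workspace, 0, bytes, stream));
  return ffi::Error::Success();
}

ffi::Error example_handler(cudaStream_t stream, ffi::ScratchAllocator scratch) {
  constexpr size_t kBytes = 1 << 20;

  // FFI_CHECK: boolean condition, with a streamed message on failure.
  FFI_CHECK(kBytes % 256 == 0) << "workspace size must be a multiple of 256";

  // FFI_CHECK_ALLOC: declares `workspace` (a void *) and fails the handler
  // if the scratch allocation comes back empty.
  FFI_CHECK_ALLOC(workspace, scratch.Allocate(kBytes));

  // FFI_RET_CHECK: propagates an error returned by a helper.
  FFI_RET_CHECK(zero_workspace(workspace, kBytes, stream));

  return ffi::Error::Success();
}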

csrc/flash_attn/mha_bwd.cpp

Lines changed: 2 additions & 0 deletions
@@ -234,6 +234,7 @@ ffi::Error mha_bwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
 
   if (seqlen_q > 0) {
     launch(params, stream);
+    FFI_CUDA_CHECK(cudaStreamSynchronize(stream));
   } else {
     // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
     FFI_CUDA_CHECK(cudaMemset(dq->untyped_data(), 0, dq->size_bytes()));
 
@@ -414,6 +415,7 @@ mha_varlen_bwd_impl(
 
   if (max_seqlen_q > 0) {
     launch(params, stream);
+    FFI_CUDA_CHECK(cudaStreamSynchronize(stream));
   } else {
     // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
     FFI_CUDA_CHECK(cudaMemsetAsync(dq->untyped_data(), 0, dq->size_bytes(), stream));

csrc/flash_attn/mha_fwd.cpp

Lines changed: 2 additions & 2 deletions
@@ -136,8 +136,7 @@ ffi::Error mha_fwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch, ffi:
 
   if (seqlen_k > 0) {
     run_mha_fwd(params, stream);
-    // C10_CUDA_CHECK(cudaStreamSynchronize(stream));
-    // C10_CUDA_CHECK(cudaDeviceSynchronize());
+    FFI_CUDA_CHECK(cudaStreamSynchronize(stream));
   } else {
     FFI_CHECK(false) << "seqlen_k is zero";
     // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
 
@@ -330,6 +329,7 @@ mha_varlen_fwd_impl(
 
   if (max_seqlen_k > 0) {
     run_mha_fwd(params, stream);
+    FFI_CUDA_CHECK(cudaStreamSynchronize(stream));
   } else {
     // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
     FFI_CUDA_CHECK(cudaMemsetAsync(out->untyped_data(), 0, out->size_bytes(), stream));

src/flash_attn_jax/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = 'v0.4.0'
+__version__ = 'v0.4.1'
 from .flash import flash_mha
 from .varlen import flash_mha_varlen
 __all__ = ['flash_mha', 'flash_mha_varlen']

tests/test_sharding.py

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ def check_sharding(sharding,q,k,v):
     out = flash((q,k,v))
     check(ref_out,ref16_out,out)
 
+    check_sharding(NamedSharding(mesh, P(None,None,None,None)), q, k, v)
     check_sharding(NamedSharding(mesh, P('x',None,None,None)), q, k, v)
     check_sharding(NamedSharding(mesh, P(None,None,'x',None)), q, k, v)