Commit f1e9755

Use the FFI API to access device ids; this allows us to reduce the CUDA requirement to 12.3.
1 parent e72fac9 commit f1e9755
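
The change in a nutshell: `cudaStreamGetDevice`, previously used to recover the device id from the stream, was only introduced in CUDA 12.8, whereas XLA's FFI can hand the device ordinal to the handler directly via `ffi::DeviceOrdinal`, which binds as an extra `int32_t` parameter. A minimal sketch of the pattern, with hypothetical handler and argument names (`ExampleImpl`, `x`, `y`) standing in for the real flash-attention signatures; `FFI_CUDA_CHECK` is the repo's own macro from `csrc/flash_attn/check.h`:

#include <cuda_runtime.h>
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// The device ordinal now arrives through the FFI execution context as a
// plain int32_t, so only cudaDeviceGetAttribute (available long before
// CUDA 12.3) is needed to query device properties.
static ffi::Error ExampleImpl(cudaStream_t stream,
                              ffi::ScratchAllocator scratch,
                              int32_t device,  // bound by Ctx<ffi::DeviceOrdinal>
                              ffi::AnyBuffer x,
                              ffi::Result<ffi::AnyBuffer> y) {
  int major, minor;
  FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
  FFI_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
  // ... launch kernels on `stream`, allocating workspace via `scratch` ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(
    kExample, ExampleImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Ctx<ffi::ScratchAllocator>()
        .Ctx<ffi::DeviceOrdinal>()  // replaces the cudaStreamGetDevice call
        .Arg<ffi::AnyBuffer>()
        .Ret<ffi::AnyBuffer>());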

File tree

7 files changed: +42 −20 lines

- README.md
- csrc/flash_attn/check.h
- csrc/flash_attn/flash_api.cpp
- csrc/flash_attn/mha_bwd.cpp
- csrc/flash_attn/mha_bwd.h
- csrc/flash_attn/mha_fwd.cpp
- csrc/flash_attn/mha_fwd.h

README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ Please cite (see below) and credit FlashAttention if you use it.
 ## Installation
 
 Requirements:
-- CUDA 12.8 and above.
+- CUDA 12.3 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
 - JAX >= `0.5.*`. The custom call api changed in this version.
 
csrc/flash_attn/check.h

Lines changed: 1 addition & 1 deletion
@@ -66,4 +66,4 @@ class CheckHelper {
   if (auto _opt = (expr); _opt.has_value()) \
     dest = _opt.value(); \
   else \
-    return CheckHelper(std::string(#expr))
+    return CheckHelper(std::string(#expr))

(The rendered text of the changed line is identical on both sides, so this is presumably a whitespace-only fix.)
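
This hunk only shows the tail of one helper in check.h, an assign-or-return macro built on CheckHelper; the `FFI_CUDA_CHECK` macro used throughout the handlers below is not visible here. As a rough idea of the pattern, a plausible shape for such a macro follows; this is an assumption, and the real definition in check.h may differ:

// Hypothetical sketch, not the actual check.h definition: evaluate a CUDA
// runtime call and, on failure, return an ffi::Error naming the expression.
#define FFI_CUDA_CHECK(expr)                                      \
  do {                                                            \
    if (cudaError_t _err = (expr); _err != cudaSuccess)           \
      return ffi::Error::Internal(std::string(#expr) + ": " +     \
                                  cudaGetErrorString(_err));      \
  } while (0)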

csrc/flash_attn/flash_api.cpp

Lines changed: 4 additions & 0 deletions
@@ -307,6 +307,7 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Ctx<ffi::ScratchAllocator>()
+        .Ctx<ffi::DeviceOrdinal>()
         .Arg<ffi::AnyBuffer>()
         .Arg<ffi::AnyBuffer>()
         .Arg<ffi::AnyBuffer>()
@@ -323,6 +324,7 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Ctx<ffi::ScratchAllocator>()
+        .Ctx<ffi::DeviceOrdinal>()
         .Arg<ffi::AnyBuffer>() // dout
         .Arg<ffi::AnyBuffer>() // q
         .Arg<ffi::AnyBuffer>() // k
@@ -343,6 +345,7 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Ctx<ffi::ScratchAllocator>()
+        .Ctx<ffi::DeviceOrdinal>()
         .Arg<ffi::AnyBuffer>() // q
         .Arg<ffi::AnyBuffer>() // k
         .Arg<ffi::AnyBuffer>() // v
@@ -366,6 +369,7 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Ctx<ffi::ScratchAllocator>()
+        .Ctx<ffi::DeviceOrdinal>()
         .Arg<ffi::AnyBuffer>() // dout
         .Arg<ffi::AnyBuffer>() // q
         .Arg<ffi::AnyBuffer>() // k
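
Note that each `.Ctx<ffi::DeviceOrdinal>()` is inserted between the `ScratchAllocator` context and the first buffer argument, mirroring the position of the new `int32_t device` parameter in the `*_impl` functions: XLA's FFI matches Ctx/Arg/Ret entries to implementation parameters in declaration order. A tiny illustration (names hypothetical):

// Binding order and parameter order must agree; ffi::DeviceOrdinal is
// decoded into the int32_t parameter at the position where it is bound.
static ffi::Error TinyImpl(cudaStream_t stream,  // .Ctx<ffi::PlatformStream<cudaStream_t>>()
                           int32_t device,       // .Ctx<ffi::DeviceOrdinal>()
                           ffi::AnyBuffer x) {   // .Arg<ffi::AnyBuffer>()
  (void)stream; (void)device; (void)x;
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(kTiny, TinyImpl,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<cudaStream_t>>()
                           .Ctx<ffi::DeviceOrdinal>()
                           .Arg<ffi::AnyBuffer>());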

csrc/flash_attn/mha_bwd.cpp

Lines changed: 5 additions & 5 deletions
@@ -110,6 +110,7 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 }
 
 ffi::Error mha_bwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
+                        int32_t device,
                         ffi::AnyBuffer dout, // batch_size x seqlen_q x num_heads x head_size_og
                         ffi::AnyBuffer q,    // batch_size x seqlen_q x num_heads x head_size
                         ffi::AnyBuffer k,    // batch_size x seqlen_k x num_heads_k x head_size
@@ -121,9 +122,8 @@ ffi::Error mha_bwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
                         ffi::Result<ffi::AnyBuffer> dv, // batch_size x seqlen_k x num_heads_k x head_size
                         double softmax_scale, bool is_causal,
                         int64_t window_size_left, int64_t window_size_right) {
-    int device, major, minor, sm_count;
-    FFI_CUDA_CHECK(cudaStreamGetDevice(stream, &device));
-    FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+    int major, minor, sm_count;
+    FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
 
@@ -249,6 +249,7 @@ ffi::Error
 mha_varlen_bwd_impl(
     cudaStream_t stream,
     ffi::ScratchAllocator scratch,
+    int32_t device,
     ffi::AnyBuffer dout, // total_q x num_heads, x head_size
     ffi::AnyBuffer q,    // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer k,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -270,8 +271,7 @@
     bool deterministic) {
 
     if (is_causal) { window_size_right = 0; }
-    int device, major, minor, sm_count;
-    FFI_CUDA_CHECK(cudaStreamGetDevice(stream, &device));
+    int major, minor, sm_count;
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));

csrc/flash_attn/mha_bwd.h

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 namespace ffi = xla::ffi;
 
 ffi::Error mha_bwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
+                        int32_t device,
                         ffi::AnyBuffer dout, ffi::AnyBuffer q, ffi::AnyBuffer k,
                         ffi::AnyBuffer v, ffi::AnyBuffer o,
                         ffi::Buffer<ffi::F32> lse, ffi::Result<ffi::AnyBuffer> dq,
@@ -21,6 +22,7 @@ ffi::Error
 mha_varlen_bwd_impl(
     cudaStream_t stream,
     ffi::ScratchAllocator scratch,
+    int32_t device,
     ffi::AnyBuffer dout, // total_q x num_heads, x head_size
     ffi::AnyBuffer q,    // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer k,    // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i

csrc/flash_attn/mha_fwd.cpp

Lines changed: 15 additions & 8 deletions
@@ -30,12 +30,19 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
 }
 
 
-ffi::Error mha_fwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer q, ffi::AnyBuffer k,
-                        ffi::AnyBuffer v, ffi::Result<ffi::AnyBuffer> o,
-                        ffi::ResultBuffer<ffi::F32> lse, double softmax_scale,
-                        bool is_causal, int64_t window_size_left, int64_t window_size_right) {
-    int device, major, minor;
-    FFI_CUDA_CHECK(cudaStreamGetDevice(stream, &device));
+ffi::Error mha_fwd_impl(cudaStream_t stream,
+                        ffi::ScratchAllocator scratch,
+                        int32_t device,
+                        ffi::AnyBuffer q,
+                        ffi::AnyBuffer k,
+                        ffi::AnyBuffer v,
+                        ffi::Result<ffi::AnyBuffer> o,
+                        ffi::ResultBuffer<ffi::F32> lse,
+                        double softmax_scale,
+                        bool is_causal,
+                        int64_t window_size_left,
+                        int64_t window_size_right) {
+    int major, minor;
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
 
@@ -151,6 +158,7 @@ ffi::Error
 mha_varlen_fwd_impl(
     cudaStream_t stream,
     ffi::ScratchAllocator scratch,
+    int32_t device,
     ffi::AnyBuffer q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -187,8 +195,7 @@ mha_varlen_fwd_impl(
     // const bool return_softmax,
     // c10::optional<at::Generator> gen_) {
 
-    int device, major, minor, sm_count;
-    FFI_CUDA_CHECK(cudaStreamGetDevice(stream, &device));
+    int major, minor, sm_count;
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
     FFI_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));

csrc/flash_attn/mha_fwd.h

Lines changed: 14 additions & 5 deletions
@@ -10,16 +10,25 @@
 
 namespace ffi = xla::ffi;
 
-ffi::Error mha_fwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
-                        ffi::AnyBuffer q, ffi::AnyBuffer k, ffi::AnyBuffer v,
-                        ffi::Result<ffi::AnyBuffer> o,
-                        ffi::ResultBuffer<ffi::F32> lse, double softmax_scale,
-                        bool is_causal, int64_t window_size_left, int64_t window_size_right);
+ffi::Error mha_fwd_impl(
+    cudaStream_t stream,
+    ffi::ScratchAllocator scratch,
+    int32_t device,
+    ffi::AnyBuffer q,
+    ffi::AnyBuffer k,
+    ffi::AnyBuffer v,
+    ffi::Result<ffi::AnyBuffer> o,
+    ffi::ResultBuffer<ffi::F32> lse,
+    double softmax_scale,
+    bool is_causal,
+    int64_t window_size_left,
+    int64_t window_size_right);
 
 ffi::Error
 mha_varlen_fwd_impl(
     cudaStream_t stream,
     ffi::ScratchAllocator scratch,
+    int32_t device,
     ffi::AnyBuffer q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
     ffi::AnyBuffer v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
