
Commit ebd0368

Make Flash Attention work on Windows (#21015)
### Description

Previously, Flash Attention only worked on Linux systems. This PR enables it to be built and run on Windows as well.

Limitation of Flash Attention on Windows: requires CUDA 12.

### Motivation and Context

This significantly increases the performance of Windows-based LLMs on hardware with sm >= 80. To illustrate the improvement of Flash Attention over Memory Efficient Attention, here are average benchmark numbers for the GQA operator, run with configurations based on several recent models (Llama, Mixtral, Phi-3). The benchmarks were obtained on an RTX 4090 GPU using the test script at onnxruntime/test/python/transformers/benchmark_gqa_windows.py.

* Clarifying note: these benchmarks cover only the GQA operator, not the entire model.

### Memory Efficient Attention Kernel Benchmarks

| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|-------------------------------------|--------|-------------|-------------|
| Llama3-8B (Average Prompt)          | 8192   | 0.19790525  | 13105.63425 |
| Llama3-8B (Average Token)           | 8192   | 0.207775538 | 12025.10172 |
| Llama3-70B (Average Prompt)         | 8192   | 0.216049167 | 11563.31185 |
| Llama3-70B (Average Token)          | 8192   | 0.209730731 | 12284.38149 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768  | 0.371928785 | 7031.440056 |
| Mixtral-8x22B-v0.1 (Average Token)  | 32768  | 0.2996659   | 7607.947159 |
| Phi-3-mini-128k (Average Prompt)    | 131072 | 0.183195867 | 15542.0852  |
| Phi-3-mini-128k (Average Token)     | 131072 | 0.198215688 | 12874.53494 |
| Phi-3-small-128k (Average Prompt)   | 65536  | 2.9884929   | 2332.584142 |
| Phi-3-small-128k (Average Token)    | 65536  | 0.845072406 | 2877.85822  |
| Phi-3-medium-128K (Average Prompt)  | 32768  | 0.324974429 | 8094.909517 |
| Phi-3-medium-128K (Average Token)   | 32768  | 0.263662567 | 8978.463687 |

### Flash Attention Kernel Benchmarks

| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|-------------------------------------|--------|-------------|-------------|
| Llama3-8B (Average Prompt)          | 8192   | 0.163566292 | 16213.69057 |
| Llama3-8B (Average Token)           | 8192   | 0.161643692 | 16196.14715 |
| Llama3-70B (Average Prompt)         | 8192   | 0.160510375 | 17448.67753 |
| Llama3-70B (Average Token)          | 8192   | 0.169427308 | 14702.62043 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768  | 0.164121964 | 15618.51301 |
| Mixtral-8x22B-v0.1 (Average Token)  | 32768  | 0.1715865   | 14524.32273 |
| Phi-3-mini-128k (Average Prompt)    | 131072 | 0.167527167 | 14576.725   |
| Phi-3-mini-128k (Average Token)     | 131072 | 0.175940594 | 15762.051   |
| Phi-3-small-128k (Average Prompt)   | 65536  | 0.162719733 | 17824.494   |
| Phi-3-small-128k (Average Token)    | 65536  | 0.14977525  | 16749.19858 |
| Phi-3-medium-128K (Average Prompt)  | 32768  | 0.156490786 | 17679.2513  |
| Phi-3-medium-128K (Average Token)   | 32768  | 0.165333833 | 14932.26079 |

Flash Attention is consistently faster for every configuration we benchmarked, with improvements in our trials ranging from ~20% to ~650% (see the short recomputation after this description). In addition to these performance gains, Flash Attention has better memory usage: for example, Memory Efficient Attention cannot handle a max sequence length higher than 32,768, while Flash Attention handles max sequence lengths at least as high as 131,072.

---------

Co-authored-by: Tianlei Wu <[email protected]>
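As a quick sanity check of the quoted ~20% to ~650% range, the per-configuration improvement can be recomputed from the throughput columns of the two tables above. Below is a minimal Python sketch using two representative rows; the dictionaries simply restate table values and nothing here is part of the benchmark script.

```python
# Recompute Flash Attention's improvement over Memory Efficient Attention
# from the throughput columns of the tables above (two representative rows).
mea_throughput = {
    "Llama3-8B (Average Prompt)": 13105.63425,
    "Phi-3-small-128k (Average Prompt)": 2332.584142,
}
flash_throughput = {
    "Llama3-8B (Average Prompt)": 16213.69057,
    "Phi-3-small-128k (Average Prompt)": 17824.494,
}

for name, baseline in mea_throughput.items():
    gain = flash_throughput[name] / baseline - 1.0
    print(f"{name}: {gain:.0%} higher throughput with Flash Attention")
# Llama3-8B prompts: ~24% higher; Phi-3-small-128k prompts: ~664% higher,
# roughly the low and high ends of the range quoted above.
```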
1 parent 269d9b0 commit ebd0368

34 files changed: +1397 −862 lines

.github/workflows/lint.yml

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ jobs:
           --exclude=java/src/main/native/*.c
           --exclude=onnxruntime/core/mlas/inc/*
           --exclude=onnxruntime/core/mlas/lib/*
+          --exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/*
         filter: "-runtime/references"

   lint-js:

.lintrunner.toml

Lines changed: 1 addition & 0 deletions

@@ -136,6 +136,7 @@ exclude_patterns = [
     'onnxruntime/core/mickey/cutlass_ext/**',  # CUTLASS based libs recommends NO automatic code formatting
     'onnxruntime/core/mickey/gemm/**',  # CUTLASS based libs recommends NO automatic code formatting
     'winml/lib/Api.Image/shaders/**',  # Contains data chunks
+    'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h',  # Bool Switches hang Clang
 ]
 command = [
     'python',

cmake/CMakeLists.txt

Lines changed: 4 additions & 1 deletion

@@ -102,7 +102,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
 option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

-cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
+cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)

@@ -734,6 +734,9 @@ if (onnxruntime_USE_CUDA)
     message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
     set(onnxruntime_USE_FLASH_ATTENTION OFF)
     set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
+  elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
+    message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0")
+    set(onnxruntime_USE_FLASH_ATTENTION OFF)
   endif()
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
     message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
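Taken together, the two CMake changes stop disabling flash attention outright on Windows and instead gate it purely on the CUDA toolkit version: the pre-existing < 11.6 check plus a new Windows-only < 12 check. A minimal Python sketch of the resulting decision logic, illustrative only and not part of the build system:

```python
def flash_attention_enabled(use_cuda: bool, is_windows: bool, cuda_version: tuple) -> bool:
    """Mirror of the CMake gating above: CUDA is required, CUDA >= 11.6 in general,
    and CUDA >= 12 when building on Windows."""
    if not use_cuda:
        return False  # cmake_dependent_option depends on onnxruntime_USE_CUDA
    if cuda_version < (11, 6):
        return False  # existing check (also turns off memory efficient attention)
    if is_windows and cuda_version < (12, 0):
        return False  # the new check added by this PR
    return True

assert flash_attention_enabled(True, is_windows=False, cuda_version=(11, 8))
assert not flash_attention_enabled(True, is_windows=True, cuda_version=(11, 8))
assert flash_attention_enabled(True, is_windows=True, cuda_version=(12, 2))
```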

onnxruntime/contrib_ops/cuda/bert/attention.cc

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
         parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
         parameters.head_size, device_prop.multiProcessorCount);
-    parameters.num_splits = num_splits;
+    parameters.num_splits = static_cast<int>(num_splits);
     softmax_lse_accum_bytes = slse_accum_bytes;
     out_accum_bytes = o_accum_bytes;
   }

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 5 additions & 0 deletions

@@ -334,6 +334,11 @@ Status FlashAttention(
     contrib::AttentionParameters& parameters,
     AttentionData<float>& data,
     float scale) {
+  ORT_UNUSED_PARAMETER(device_prop);
+  ORT_UNUSED_PARAMETER(stream);
+  ORT_UNUSED_PARAMETER(parameters);
+  ORT_UNUSED_PARAMETER(data);
+  ORT_UNUSED_PARAMETER(scale);
   return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor");
 }
 #endif

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+#include <cmath>
+#include <cute/tensor.hpp>
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include "utils.h"
+
+namespace onnxruntime {
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_causal>
+struct Alibi {
+  const float alibi_slope;
+  const int max_seqlen_k, max_seqlen_q;
+
+  __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+      : alibi_slope(alibi_slope), max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q){};
+
+  template <typename Engine, typename Layout>
+  __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout>& tensor,
+                                              const int col_idx_offset_,
+                                              const int row_idx_offset,
+                                              const int warp_row_stride) {
+    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+      #pragma unroll
+      for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
+        #pragma unroll
+        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+          const int col_idx = col_idx_base + j;
+          #pragma unroll
+          for (int mi = 0; mi < size<0>(tensor); ++mi) {
+            tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+          }
+        }
+      }
+    } else {  // Bias depends on both row_idx and col_idx
+      #pragma unroll
+      for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        #pragma unroll
+        for (int i = 0; i < size<0, 0>(tensor); ++i) {
+          const int row_idx = row_idx_base + i * 8;
+          #pragma unroll
+          for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+              const int col_idx = col_idx_base + j;
+              tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+}  // namespace flash
+}  // namespace onnxruntime
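The new file above adds ALiBi positional bias inside the tiled kernel, where the per-thread MMA indexing hides a simple formula. Here is a hedged NumPy reference of the same bias at whole-matrix granularity; alibi_bias is an illustrative helper, not part of the kernel, and for causal attention the kernel's additive form differs from the non-causal one only by a per-row constant, which softmax ignores.

```python
import numpy as np

def alibi_bias(slope: float, seqlen_q: int, seqlen_k: int, causal: bool) -> np.ndarray:
    """Bias the Alibi struct adds to attention scores, written densely.

    causal=True:  bias[r, c] = slope * c                          (same vector for every row)
    causal=False: bias[r, c] = -slope * |r + seqlen_k - seqlen_q - c|
    """
    rows = np.arange(seqlen_q)[:, None]
    cols = np.arange(seqlen_k)[None, :]
    if causal:
        return (slope * np.broadcast_to(cols, (seqlen_q, seqlen_k))).astype(np.float32)
    return (-slope * np.abs(rows + seqlen_k - seqlen_q - cols)).astype(np.float32)

# Example: 4 query rows attending over 6 key columns with slope 0.5.
print(alibi_bias(0.5, seqlen_q=4, seqlen_k=6, causal=False))
```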

onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h

Lines changed: 20 additions & 7 deletions

@@ -12,22 +12,36 @@ struct BlockInfo {
   template <typename Params>
   __device__ BlockInfo(const Params& params, const int bidb)
       : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
-        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]),
-        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative
+                    ? -1
+                    : params.cu_seqlens_k[bidb]),
+        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
+                            ? params.seqlen_q
+                            : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         ,
-        seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])),
-        actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
+        seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
+                           ? params.seqlen_k
+                           : (params.is_seqlens_k_cumulative
+                                  ? params.cu_seqlens_k[bidb + 1] - sum_s_k
+                                  : params.cu_seqlens_k[bidb])),
+        actual_seqlen_k(params.seqused_k
+                            ? params.seqused_k[bidb]
+                            : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
   }

   template <typename index_t>
-  inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+  __forceinline__ __device__
+      index_t
+      q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
     return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
   }

   template <typename index_t>
-  inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+  __forceinline__ __device__
+      index_t
+      k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
     return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
   }


@@ -41,6 +55,5 @@ struct BlockInfo {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-////////////////////////////////////////////////////////////////////////////////////////////////////
 }  // namespace flash
 }  // namespace onnxruntime
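The behavioral change in BlockInfo is the new seqused_k path: when the caller supplies per-batch "actually used" key lengths, they take precedence over both the cumulative-length and cached-length computations. A hedged Python sketch of the resolution order follows; the field names mirror the C++ above, but the function itself is illustrative.

```python
def resolve_actual_seqlen_k(
    varlen: bool,
    seqlen_k: int,                  # params.seqlen_k (fixed length when not varlen)
    cu_seqlens_k,                   # cumulative (or per-batch) lengths, or None
    is_seqlens_k_cumulative: bool,
    seqused_k,                      # optional per-batch "actually used" lengths, or None
    seqlen_knew: int,               # length of appended K_new; 0 if no new keys
    bidb: int,                      # batch index
) -> int:
    """Mirrors the seqlen_k_cache / actual_seqlen_k initializers above."""
    if not varlen or cu_seqlens_k is None:
        seqlen_k_cache = seqlen_k
    elif is_seqlens_k_cumulative:
        seqlen_k_cache = cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]
    else:
        seqlen_k_cache = cu_seqlens_k[bidb]
    # New in this change: seqused_k, when supplied, overrides everything else.
    if seqused_k is not None:
        return seqused_k[bidb]
    return seqlen_k_cache + seqlen_knew

# Example: batch 0 has 7 cached keys and 1 appended key, but only 5 are marked as used.
print(resolve_actual_seqlen_k(True, 0, [0, 7, 12], True, [5, 9], 1, bidb=0))  # -> 5
```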

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h

Lines changed: 12 additions & 1 deletion

@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 struct Qkv_params {
-  using index_t = uint32_t;
+  using index_t = int64_t;
   // The QKV matrices.
   void* __restrict__ q_ptr = nullptr;
   void* __restrict__ k_ptr = nullptr;

@@ -79,6 +79,9 @@ struct Flash_fwd_params : public Qkv_params {
   int* __restrict__ cu_seqlens_q = nullptr;
   int* __restrict__ cu_seqlens_k = nullptr;

+  // If provided, the actual length of each k sequence.
+  int* __restrict__ seqused_k = nullptr;
+
   int* __restrict__ blockmask = nullptr;

   // The K_new and V_new matrices.

@@ -100,6 +103,11 @@ struct Flash_fwd_params : public Qkv_params {
   // The indices to index into the KV cache.
   int* __restrict__ cache_batch_idx = nullptr;

+  // Paged KV cache
+  int* __restrict__ block_table = nullptr;
+  index_t block_table_batch_stride = 0;
+  int page_block_size = 0;
+
   // Local window size
   int window_size_left = -1;
   int window_size_right = -1;

@@ -115,6 +123,9 @@ struct Flash_fwd_params : public Qkv_params {

   int num_splits = 0;  // For split-KV version

+  void* __restrict__ alibi_slopes_ptr = nullptr;
+  index_t alibi_slopes_batch_stride = 0;
+
   const cudaDeviceProp* dprops = nullptr;
 };

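Besides widening index_t from uint32_t to int64_t (presumably so batch and row strides cannot overflow 32 bits for long sequences), the new fields above describe an optional paged KV cache: block_table holds one row per batch entry, spaced block_table_batch_stride apart, mapping each logical KV page of page_block_size tokens to a physical page. A hedged Python sketch of the addressing these fields imply, illustrative only and not the kernel's actual code:

```python
def paged_kv_location(block_table, block_table_batch_stride: int,
                      page_block_size: int, bidb: int, token_idx: int):
    """Map (batch index, logical token position) to (physical page, offset within page).

    block_table is a flat int array; row `bidb` starts at
    bidb * block_table_batch_stride and lists the physical page id
    for each logical page of that sequence's KV cache.
    """
    logical_page, offset_in_page = divmod(token_idx, page_block_size)
    physical_page = block_table[bidb * block_table_batch_stride + logical_page]
    return physical_page, offset_in_page

# Example: batch 1's pages are [7, 2, 9]; with 16-token pages, token 35 of that
# sequence lives at offset 3 inside physical page 9.
table = [4, 0, 5,   # batch 0
         7, 2, 9]   # batch 1
print(paged_kv_location(table, block_table_batch_stride=3, page_block_size=16,
                        bidb=1, token_idx=35))  # -> (9, 3)
```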