
Commit 69cfcba

[CUDA] Sparse Attention support 128k sequence length (#20614)
### Description
When the sequence length is 128K, block_mask has 2048 rows, which is not supported by the previous kernel.
(1) Add a new kernel to handle more than 1024 rows, where each thread needs to handle two rows.
(2) Add a test for 128K sequence length.
1 parent a0db218 commit 69cfcba
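To make the change concrete, here is a minimal CPU-side sketch of the per-head mask-to-CSR conversion that the MaskToCSR and MaskToCSR_Large kernels perform on the GPU. This is not code from the commit; the helper name mask_to_csr_reference is illustrative.

```python
import numpy as np


def mask_to_csr_reference(mask: np.ndarray):
    """CPU reference for one head: mask is a (num_rows, num_cols) 0/1 array."""
    num_rows, _ = mask.shape
    csr_row_indices = np.zeros(num_rows + 1, dtype=np.int32)
    csr_col_indices = []
    for row in range(num_rows):
        cols = np.nonzero(mask[row] == 1)[0]
        csr_col_indices.extend(cols.tolist())
        # Prefix sum of per-row non-zero counts; the kernels build the same
        # prefix sum in shared memory (shared_row_indices).
        csr_row_indices[row + 1] = csr_row_indices[row] + len(cols)
    return csr_row_indices, np.asarray(csr_col_indices, dtype=np.int32)


# Tiny example: a 4x4 lower-triangular block mask.
mask = np.tril(np.ones((4, 4), dtype=np.int32))
row_indices, col_indices = mask_to_csr_reference(mask)
print(row_indices)  # [ 0  1  3  6 10]
print(col_indices)  # [0 0 1 0 1 2 0 1 2 3]
```

For a 128K sequence with sparse block size 64 the block mask has 131072 / 64 = 2048 rows, which exceeds the one-thread-per-row limit of the original kernel (at most 1024 threads per block), so the new kernel assigns two rows to each thread in that case.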

File tree: 2 files changed, +133 −23 lines

onnxruntime/contrib_ops/cuda/sparse/block_mask.cu

Lines changed: 73 additions & 17 deletions
@@ -27,27 +27,28 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
     }
   }
 
-  extern __shared__ int non_zero_counts[];
-  non_zero_counts[threadIdx.x] = count;
+  extern __shared__ int shared_row_indices[];
+  shared_row_indices[row + 1] = count;
   __syncthreads();
 
   // The first thread will calculate the accumulated partial sum of non-zero counts.
+  // The result is csr_row_indices stored in shared memory.
   if (row == 0) {
+    shared_row_indices[0] = 0;
     for (int i = 1; i < num_rows; i++) {
-      non_zero_counts[i] += non_zero_counts[i - 1];
+      shared_row_indices[i + 1] += shared_row_indices[i];
     }
+
+    // The first thread outputs the last element.
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
   }
   __syncthreads();
 
-  // The starting index of current row in csr_col_indices
-  int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
+  // The starting index of current row in csr_col_indices
+  int offset = shared_row_indices[row];
 
   // Output row indices.
   csr_row_indices[row] = offset;
-  if (row == 0) {
-    // The first thread output the last element.
-    csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
-  }
 
   for (int col = 0; col < num_cols; col++) {
     if (mask[row * num_cols + col] == 1) {
@@ -60,6 +61,59 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
   // The last element of csr_row_indices is the total number of non-zero elements.
 }
 
+__global__ void MaskToCSR_Large(const int* mask,
+                                int* csr_row_indices,
+                                int* csr_col_indices,
+                                int num_rows,
+                                int num_cols,
+                                int rows_per_thread  // Each thread handles multiple rows
+) {
+  extern __shared__ int shared_row_indices[];
+
+  // Update input and output data pointers to the start of current head
+  int head = blockIdx.x;
+  mask += head * num_rows * num_cols;
+  csr_row_indices += head * (num_rows + 1);
+  csr_col_indices += head * num_rows * num_cols;
+
+  int tid = threadIdx.x;
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int count = 0;
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        count++;
+      }
+    }
+    shared_row_indices[row + 1] = count;
+  }
+
+  __syncthreads();
+
+  // The first thread will calculate the accumulated partial sum of non-zero counts.
+  if (tid == 0) {
+    shared_row_indices[0] = 0;
+    for (int i = 1; i < num_rows; i++) {
+      shared_row_indices[i + 1] += shared_row_indices[i];
+    }
+
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
+  }
+
+  __syncthreads();
+
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int offset = shared_row_indices[row];
+    csr_row_indices[row] = offset;
+
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        csr_col_indices[offset] = col;
+        offset++;
+      }
+    }
+  }
+}
+
 void ConvertMaskToCSR(cudaStream_t stream,
                       const int* mask,       // input mask with shape (num_layout, num_rows, num_cols)
                       int num_layout,        // number of layouts
@@ -68,15 +122,17 @@ void ConvertMaskToCSR(cudaStream_t stream,
                       int* csr_row_indices,  // output CSR row indices
                       int* csr_col_indices,  // output CSR column indices
                       int max_threads_per_block) {
-  int threads_per_block = (num_rows + 31) / 32 * 32;
-
-  // Each thread handle one row. The kernel assumes that all rows of one head can be handled in one block.
-  if (threads_per_block > max_threads_per_block) {
-    ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
+  if (num_rows <= max_threads_per_block) {
+    // Each thread handle one row.
+    MaskToCSR<<<num_layout, num_rows, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
+  } else {
+    // Each thread will handle multiple rows when number of rows > max_threads_per_block.
+    // For example 128K length with sparse block size 64 will have 2048 rows. Each thread will handle 2 rows.
+    int rows_per_thread = (num_rows + max_threads_per_block - 1) / max_threads_per_block;
+    MaskToCSR_Large<<<num_layout, max_threads_per_block, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols, rows_per_thread);
   }
-
-  MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
-      mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
 }
 
 }  // namespace cuda
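As a quick sanity check on the new dispatch in ConvertMaskToCSR, the launch parameters work out as follows. This is a sketch, not code from the commit, and it assumes the common CUDA limit of 1024 threads per block.

```python
max_threads_per_block = 1024  # assumed per-block thread limit used by the dispatch

# 2048 rows corresponds to a 128K sequence with sparse block size 64.
for num_rows in (128, 1024, 2048):
    if num_rows <= max_threads_per_block:
        kernel, rows_per_thread = "MaskToCSR", 1  # one thread per row
    else:
        # Ceiling division, the same formula as the C++ dispatch.
        kernel = "MaskToCSR_Large"
        rows_per_thread = (num_rows + max_threads_per_block - 1) // max_threads_per_block
    shared_bytes = (num_rows + 1) * 4  # (num_rows + 1) int32 row offsets in shared memory
    print(f"{num_rows:5d} rows -> {kernel:15s} rows_per_thread={rows_per_thread} shared={shared_bytes} bytes")
```

For 2048 rows this selects MaskToCSR_Large with rows_per_thread=2 and 8196 bytes of dynamic shared memory, comfortably below the common 48 KB per-block default.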

onnxruntime/test/python/transformers/test_sparse_attention.py

Lines changed: 60 additions & 6 deletions
@@ -4,9 +4,7 @@
 # --------------------------------------------------------------------------
 
 """
-Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 8.x.
-Install required packages before running this script:
-    pip install matplotlib pandas onnx torch onnxruntime-gpu
+Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 7.5 or above.
 """
 import math
 import unittest
@@ -726,13 +724,13 @@ def test_sparse_attention(self):
         self.run_relevance_test(sm)
 
     def run_one_relevance_test(self, config: SparseAttentionConfig):
-        if not config.do_rotary:
-            # Run QGA by Torch
+        if (not config.do_rotary) and config.total_sequence_length <= 2048:
+            # Run QGA by Torch (support mask, but not packed QKV, rotary and very long sequence)
             gqa_config: GroupQueryAttentionConfig = config.get_comparable_torch_gqa_config(use_sparse=True)
             obj = TorchGroupQueryAttention(gqa_config)
             expected_out = obj.infer()
         else:
-            # Run QGA by ORT
+            # Run QGA by ORT (support packed QKV, rotary and very long sequence, but no mask so dense only).
             gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
             obj = OrtGroupQueryAttention(gqa_config)
             ort_qga_outputs = obj.infer()
@@ -820,10 +818,66 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool):
             config.dtype = torch.bfloat16
             self.run_one_relevance_test(config)
 
+    def run_relevance_no_past_128k(self, sm: int, device):
+        """Test kernel could support up to 128K sequence length."""
+        for seq_len in [131072]:
+            for packed_qkv in [False, True]:
+                config = SparseAttentionConfig(
+                    batch_size=1,
+                    sequence_length=seq_len,
+                    max_sequence_length=131072,
+                    past_sequence_length=0,
+                    num_heads=1,
+                    kv_num_heads=1,
+                    head_size=128,
+                    sparse_block_size=64,
+                    num_layout=1,
+                    local_blocks=2048,  # use dense to compare with GQA
+                    vert_stride=8,
+                    softmax_scale=None,
+                    device=device,
+                    is_packed_qkv=packed_qkv,
+                )
+                self.run_one_relevance_test(config)
+
+                if sm >= 80 and not packed_qkv:
+                    config.dtype = torch.bfloat16
+                    self.run_one_relevance_test(config)
+
+    def run_relevance_past_128k(self, sm: int, device):
+        """Test kernel could support up to 128K sequence length."""
+        for past_seq_len in [131071]:
+            for packed_qkv in [False, True]:
+                config = SparseAttentionConfig(
+                    batch_size=1,
+                    sequence_length=1,
+                    max_sequence_length=131072,
+                    past_sequence_length=past_seq_len,
+                    num_heads=1,
+                    kv_num_heads=1,
+                    head_size=128,
+                    sparse_block_size=64,
+                    num_layout=1,
+                    local_blocks=2048,  # use dense to compare with GQA
+                    vert_stride=8,
+                    softmax_scale=None,
+                    device=device,
+                    is_packed_qkv=packed_qkv,
+                )
+                self.run_one_relevance_test(config)
+
+                if sm >= 80 and not packed_qkv:
+                    config.dtype = torch.bfloat16
+                    self.run_one_relevance_test(config)
+
     def run_relevance_test(self, sm: int):
         device_id = torch.cuda.current_device()
         device = torch.device("cuda", device_id)
         with torch.no_grad():
+            # Test long sequence when GPU memory is enough (need about 12 GB for 128K sequence length)
+            if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024:
+                self.run_relevance_no_past_128k(sm, device)
+                self.run_relevance_past_128k(sm, device)
             self.run_relevance_no_past(sm, device)
             self.run_relevance_past(sm, device, do_rotary=False)
             self.run_relevance_past(sm, device, do_rotary=True)
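The new `config.total_sequence_length <= 2048` gate on the Torch reference path is consistent with memory limits. The following back-of-envelope estimate is my assumption about the reason (the commit only says the Torch path does not support very long sequences): a dense attention-score matrix for a 128K sequence would be far too large to materialize, so the long-sequence configs are compared against the ORT GQA (dense) operator instead.

```python
# Rough size of a dense attention-score matrix for one head at 128K tokens.
seq_len = 131072
bytes_per_element = 2  # float16 / bfloat16
score_matrix_bytes = seq_len * seq_len * bytes_per_element
print(f"{score_matrix_bytes / 2**30:.0f} GiB")  # 32 GiB for a single head
```

The test module itself remains a standard unittest script, so the new 128K cases run together with the existing test_sparse_attention cases whenever the GPU reports more than 13 GB of total memory.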
