19 changes: 4 additions & 15 deletions custom_ops/gpu_ops/append_attention.cu
@@ -59,7 +59,6 @@ void AppendAttentionKernel(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
int max_kv_len_this_time = set_max_lengths.data<int>()[8];

auto main_stream = qkv.stream();
static cudaEvent_t main_event;
@@ -245,7 +245,6 @@ void AppendAttentionKernel(

if (max_just_dec_len_this_time > 0) {
int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];

cudaStream_t exec_stream;
if (max_enc_len_this_time > 0) {
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
case paddle::DataType::INT8:{
int8_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
break;
}
case paddle::DataType::FLOAT8_E4M3FN:{
phi::dtype::float8_e4m3fn tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
break;
}
}
} else {
data_t tmp;
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
}
if (max_enc_len_this_time > 0) {
cudaEventRecord(decoder_event, exec_stream);
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const std::vector<int64_t>& fmha_out_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::DataType& fmha_out_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
paddle::Optional("qkv_bias"),
@@ -1107,7 +1097,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
"fmha_out",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
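Note on the surrounding context: the unchanged lines above overlap the prefill (encoder) and decode attention paths on two CUDA streams, synchronized with events (main_event, decoder_event). A minimal, self-contained sketch of that pattern follows; the kernel names and launch shapes are placeholders, not code from this repository.

#include <cuda_runtime.h>

__global__ void preprocess() {}    // stands in for the shared QKV preprocessing
__global__ void encoder_part() {}  // stands in for the prefill attention kernels
__global__ void decoder_part() {}  // stands in for the decode attention kernels

int main() {
  cudaStream_t main_stream, exec_stream;
  cudaEvent_t main_event, decoder_event;
  cudaStreamCreate(&main_stream);
  cudaStreamCreate(&exec_stream);
  cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&decoder_event, cudaEventDisableTiming);

  preprocess<<<1, 32, 0, main_stream>>>();
  cudaEventRecord(main_event, main_stream);            // decode work may start after this point
  cudaStreamWaitEvent(exec_stream, main_event, 0);
  encoder_part<<<1, 32, 0, main_stream>>>();           // prefill continues on the main stream
  decoder_part<<<1, 32, 0, exec_stream>>>();           // decode overlaps on the side stream
  cudaEventRecord(decoder_event, exec_stream);
  cudaStreamWaitEvent(main_stream, decoder_event, 0);  // rejoin before the outputs are consumed

  cudaStreamSynchronize(main_stream);
  cudaStreamDestroy(main_stream);
  cudaStreamDestroy(exec_stream);
  cudaEventDestroy(main_event);
  cudaEventDestroy(decoder_event);
  return 0;
}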
158 changes: 73 additions & 85 deletions custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -19,7 +19,7 @@

template <int THREADBLOCK_SIZE>
__global__ void
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
@@ -37,41 +37,28 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
int max_just_dec_merged_len_this_time_this_thread = 0;
int max_system_len_this_thread = 0;
int max_dec_len_without_system_this_thread = 0;
int max_len_kv_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
const int seq_len_decoder = seq_lens_decoder[i];
max_len_this_time_this_thread =
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);

if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_this_thread =
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);
if (system_lens) {
const int real_bid = seq_mapping[i];
const int system_len_now = system_lens[real_bid];
max_system_len_this_thread =
max(max_system_len_this_thread, system_len_now);
max_dec_len_without_system_this_thread =
max(max_dec_len_without_system_this_thread,
max_just_dec_len_now - system_len_now);
}
}
if (system_lens) {
for (int i = tid; i < batch_size; i += blockDim.x) {
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
if (ori_seq_len_this_time <= 0)
continue;
const int max_just_dec_merged_len_this_time_now =
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
max_just_dec_merged_len_this_time_this_thread =
max(max_just_dec_merged_len_this_time_this_thread,
max_just_dec_merged_len_this_time_now);
}

if (seq_len_decoder == 0)
continue;
max_len_kv_this_thread =
max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
}
int total_max_len_this_time =
BlockReduce(temp_storage)
@@ -86,23 +73,18 @@
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage)
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged =
BlockReduce(temp_storage)
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage)
.Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_kv =
BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
if (tid == 0) {
max_lens[0] = total_max_len_this_time;
max_lens[1] = total_max_len_encoder;
max_lens[2] = total_max_len_decoder;
max_lens[3] = total;
max_lens[4] = total_just_dec;
max_lens[5] = total_just_dec_merged;
max_lens[6] = total_system_len;
max_lens[7] = total_dec_len_without_system;
max_lens[5] = max_just_dec_merged_len_this_time_this_thread;
max_lens[6] = max_system_len_this_thread;
max_lens[7] = max_dec_len_without_system_this_thread;
max_lens[8] = total_max_len_kv;
}
}
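For reference, a sketch of the slot layout that GetMaxLenKernel writes into max_lens after this change. The indices come from the diff above; the enum and its names are illustrative only, not part of the PR. Slots 0-4 and 8 are block-reduced maxima; per the diff, slots 5-7 are now written from thread 0's local variables, since their reductions were removed along with the system_lens path.

// Illustrative only: names for the nine max_lens slots filled by GetMaxLenKernel.
enum MaxLenSlot {
  kMaxLenThisTime = 0,               // max(seq_lens_this_time)
  kMaxEncLenThisTime = 1,            // max(seq_lens_encoder)
  kMaxDecLenThisTime = 2,            // max(seq_lens_decoder)
  kMaxEncDecLenThisTime = 3,         // max(seq_lens_decoder[i] + seq_lens_this_time[i])
  kMaxJustDecLenThisTime = 4,        // decode-only requests (seq_lens_encoder[i] == 0)
  kMaxJustDecMergedLenThisTime = 5,  // thread-0 local after this change
  kMaxSystemLen = 6,                 // thread-0 local after this change
  kMaxDecLenWithoutSystem = 7,       // thread-0 local after this change
  kMaxKVLenThisTime = 8              // max(seq_lens_this_time[i] + seq_lens_decoder[i]), skipping seq_lens_decoder[i] == 0
};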

@@ -208,25 +190,68 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
int *__restrict__ num_blocks_x,
const int bsz,
const int num_rows_per_block,
const int group_size) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
// expects a single-block, single-warp launch (32 threads)
const int lane = threadIdx.x % warpSize;

__shared__ int global_offset;
if (threadIdx.x == 0) global_offset = 0;
__syncthreads();

// loop over the batch in warp-sized tiles [base, base + warpSize)
for (int base = 0; base < bsz; base += warpSize) {
const int bid = base + lane;
const bool active = (bid < bsz);

// calculate loop_times for bid
int loop_times = 0;
if (active) {
int seq_len = seq_lens_q[bid];
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
seq_len = 0;
}
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
loop_times = div_up(seq_len * group_size, num_rows_per_block);
}

// prefix sum for each lane, get the start offset in this tile
unsigned mask = __ballot_sync(0xffffffff, active);
// inclusive scan
int x = loop_times;
for (int offset = 1; offset < warpSize; offset <<= 1) {
int y = __shfl_up_sync(mask, x, offset);
if (lane >= offset) x += y;
}
int excl = x - loop_times; // exclusive prefix
int tile_sum = __reduce_add_sync(mask, loop_times); // warp tile sum

// write batch_ids and tile_ids_per_batch
int base_offset;
if (lane == 0) {
base_offset = global_offset;
}
base_offset = __shfl_sync(mask, base_offset, 0);
if (active && loop_times > 0) {
int write_base = base_offset + excl;
// [write_base, write_base+loop_times)
for (int t = 0; t < loop_times; ++t) {
int pos = write_base + t;
batch_ids[pos] = bid;
tile_ids_per_batch[pos] = t;
}
gridx += loop_times;
}
*num_blocks_x = gridx;

// for next warp tile
if (lane == 0) {
global_offset += tile_sum;
}
__syncthreads();
}

if (threadIdx.x == 0) {
*num_blocks_x = global_offset;
}
}
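The rewritten split_q_block above replaces the old single-thread loop with a warp-parallel version: each lane computes its batch's tile count, a shuffle-based inclusive scan turns those counts into per-lane write offsets, and lane 0 advances a shared global offset per warp tile. A self-contained sketch of that scan pattern follows; it assumes one full warp and is illustrative only, not code from this repository.

#include <cstdio>
#include <cuda_runtime.h>

// Warp-wide inclusive scan using __shfl_up_sync, the same shuffle pattern used above.
__device__ int warp_inclusive_scan(int value) {
  const int lane = threadIdx.x % warpSize;
  for (int offset = 1; offset < warpSize; offset <<= 1) {
    int y = __shfl_up_sync(0xffffffffu, value, offset);
    if (lane >= offset) value += y;  // lanes below `offset` keep their running sum
  }
  return value;  // lane i now holds counts[0] + ... + counts[i]
}

// tile_counts[i] plays the role of loop_times for batch i; write_base[i] is the
// first output slot lane i would write, i.e. the exclusive prefix sum.
__global__ void scan_demo(const int *tile_counts, int *write_base) {
  int v = tile_counts[threadIdx.x];
  int inclusive = warp_inclusive_scan(v);
  write_base[threadIdx.x] = inclusive - v;
}

int main() {
  int h_counts[32], h_base[32];
  for (int i = 0; i < 32; ++i) h_counts[i] = i % 3;  // arbitrary tile counts
  int *d_counts = nullptr, *d_base = nullptr;
  cudaMalloc(&d_counts, sizeof(h_counts));
  cudaMalloc(&d_base, sizeof(h_base));
  cudaMemcpy(d_counts, h_counts, sizeof(h_counts), cudaMemcpyHostToDevice);
  scan_demo<<<1, 32>>>(d_counts, d_base);
  cudaMemcpy(h_base, d_base, sizeof(h_base), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 32; ++i)
    printf("lane %2d: count=%d write_base=%d\n", i, h_counts[i], h_base[i]);
  cudaFree(d_counts);
  cudaFree(d_base);
  return 0;
}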

@@ -256,29 +281,6 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
}
}

template <int THREADBLOCK_SIZE>
__global__ void
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
const int *seq_lens_decoder, const int batch_size) {
const int tid = threadIdx.x;

typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

int max_len_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
if (seq_lens_decoder[i] == 0)
continue;
max_len_this_thread =
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
}
int total =
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
if (tid == 0) {
*max_seq_lens_out = total;
}
}

void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
@@ -295,7 +297,6 @@ void GetBlockShapeAndSplitKVBlock(
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -319,15 +320,7 @@ void GetBlockShapeAndSplitKVBlock(
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];

auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);


max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
int max_kv_len_this_time = max_len_cpu_ptr[8];

// decoder
if (max_dec_len_this_time > 0) {
@@ -416,12 +409,8 @@ void GetBlockShapeAndSplitKVBlock(

decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
@@ -479,7 +468,6 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
