diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index ca5ce25474..4e86052c3e 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -38,7 +38,7 @@ class type2value { template -std::vector AppendAttentionKernel( +void AppendAttentionKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -60,6 +60,7 @@ std::vector AppendAttentionKernel( const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, const paddle::Tensor& max_len_kv, + paddle::Tensor& fmha_out, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, const paddle::optional& qkv_bias, @@ -122,27 +123,6 @@ std::vector AppendAttentionKernel( } else { qkv_out = qkv; } - paddle::Tensor fmha_out; - if (out_linear_in_scale > 0.0) { - if (fabs(quant_max_bound - 127.0f) < 0.000001) { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::INT8, - qkv.place()); - } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::FLOAT8_E4M3FN, - qkv.place()); - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); - } - } else { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - D, - qkv.place()); - } auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args, const paddle::Tensor& lambda_batch_ids, @@ -405,8 +385,6 @@ std::vector AppendAttentionKernel( cudaStreamWaitEvent(main_stream, decoder_event); } } - - return {fmha_out, qkv_out}; } std::vector AppendAttention( @@ -481,12 +459,60 @@ std::vector AppendAttention( meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; + // template dtype generation + phi::DataType dtype_id; + switch (qkv.dtype()) { + case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;} + case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;} + case paddle::DataType::INT32: { + if (compute_dtype == "bf16") { + dtype_id = phi::DataType::BFLOAT16; + break; + } else if (compute_dtype == "fp16") { + dtype_id = phi::DataType::FLOAT16; + break; + } else { + PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); + break; + } + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. 
"); + break; + } + } + + // fmha_out generation, rewrite from AppendAttentionKernel + paddle::Tensor fmha_out; + if (out_linear_in_scale > 0.0) { + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::INT8, + qkv.place()); + } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::FLOAT8_E4M3FN, + qkv.place()); + } else{ + PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); + } + } else { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + dtype_id, + qkv.place()); + } + if (mask_offset) { meta_data.mask_offset = mask_offset.get().data(); } - auto dispatch_by_template = [&](auto temp_args) -> std::vector { - return AppendAttentionKernel::value>( + auto dispatch_by_template = [&](auto temp_args) -> void { + AppendAttentionKernel::value>( meta_data, qkv, key_cache, @@ -508,6 +534,7 @@ std::vector AppendAttention( decoder_num_blocks, set_max_lengths, max_len_kv, + fmha_out, rotary_embs, attn_mask, qkv_bias, @@ -539,20 +566,183 @@ std::vector AppendAttention( speculate_max_draft_token_num, causal, speculate_decoder); + }; + + + phi::dtype::float16 fp16_dtype; + phi::dtype::bfloat16 bp16_dtype; + switch (dtype_id){ + case phi::DataType::FLOAT16: { + dispatch_by_template(fp16_dtype); + return {fmha_out}; + } + case phi::DataType::BFLOAT16: { + dispatch_by_template(bp16_dtype); + return {fmha_out}; + } + default: + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. "); + break; + } + return {paddle::Tensor{}}; +} + +void AppendAttentionWithOutput( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& set_max_lengths, + const paddle::Tensor& max_len_kv, + paddle::Tensor& fmha_out, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float 
quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + auto dispatch_by_template = [&](auto temp_args) -> void { + AppendAttentionKernel::value>( + meta_data, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + max_len_kv, + fmha_out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + mask_offset, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder); }; phi::dtype::float16 fp16_dtype; phi::dtype::bfloat16 bp16_dtype; switch (qkv.dtype()) { - case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype); - case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype); + case paddle::DataType::FLOAT16: { + dispatch_by_template(fp16_dtype); + break; + } + case paddle::DataType::BFLOAT16: { + dispatch_by_template(bp16_dtype); + break; + } case paddle::DataType::INT32: { if (compute_dtype == "bf16") { - return dispatch_by_template(bp16_dtype); + dispatch_by_template(bp16_dtype); + break; } else if (compute_dtype == "fp16") { - return dispatch_by_template(fp16_dtype); + dispatch_by_template(fp16_dtype); + break; } else { PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); break; @@ -565,9 +755,9 @@ std::vector AppendAttention( break; } } - return {paddle::Tensor{}}; } + std::vector> AppendAttentionInferShape( const std::vector& qkv_shape, const std::vector& key_cache_shape, @@ -629,7 +819,7 @@ std::vector> AppendAttentionInferShape( } const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; const int num_heads = total_num_head - 2 * kv_num_heads; - return {{token_num, num_heads * head_dim}, qkv_shape}; + return {{token_num, num_heads * head_dim}}; } std::vector AppendAttentionInferDtype( @@ 
-688,32 +878,148 @@ std::vector AppendAttentionInferDtype( if (compute_dtype == "bf16") { if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { - return {paddle::DataType::INT8, paddle::DataType::BFLOAT16}; + return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16}; + return {paddle::DataType::FLOAT8_E4M3FN}; }else{ PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { - return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16}; + return {paddle::DataType::BFLOAT16}; } } else if (compute_dtype == "fp16") { if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { - return {paddle::DataType::INT8, paddle::DataType::FLOAT16}; + return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16}; + return {paddle::DataType::FLOAT8_E4M3FN}; }else{ PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { - return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16}; + return {paddle::DataType::FLOAT16}; } } else { PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); } } +std::vector> AppendAttentionWithOutputInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& encoder_batch_ids_shape, + const std::vector& encoder_tile_ids_per_batch_shape, + const std::vector& encoder_num_blocks_shape, + const std::vector& kv_batch_ids_shape, + const std::vector& kv_tile_ids_per_batch_shape, + const std::vector& kv_num_blocks_shape, + const std::vector& decoder_batch_ids_shape, + const std::vector& decoder_tile_ids_per_batch_shape, + const std::vector& decoder_num_blocks_shape, + const std::vector& set_max_lengths_shape, + const std::vector& max_len_kv_shape, + const std::vector& fmha_out_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& qkv_out_scales_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& out_linear_shifts_shape, + const paddle::optional>& out_linear_smooths_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int 
speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + return {fmha_out_shape}; +} + +std::vector AppendAttentionWithOutputInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& encoder_batch_ids_dtype, + const paddle::DataType& encoder_tile_ids_per_batch_dtype, + const paddle::DataType& encoder_num_blocks_dtype, + const paddle::DataType& kv_batch_ids_dtype, + const paddle::DataType& kv_tile_ids_per_batch_dtype, + const paddle::DataType& kv_num_blocks_dtype, + const paddle::DataType& decoder_batch_ids_dtype, + const paddle::DataType& decoder_tile_ids_per_batch_dtype, + const paddle::DataType& decoder_num_blocks_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::DataType& max_len_kv_dtype, + const paddle::DataType& fmha_out_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& qkv_out_scales_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& out_linear_shifts_dtype, + const paddle::optional& out_linear_smooths_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + return {fmha_out_dtype}; +} + + + PD_BUILD_STATIC_OP(append_attention) .Inputs({"qkv", "key_cache", @@ -774,3 +1080,65 @@ PD_BUILD_STATIC_OP(append_attention) .SetKernelFn(PD_KERNEL(AppendAttention)) .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype)); + +PD_BUILD_STATIC_OP(append_attention_with_output) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks", + "set_max_lengths", + "max_len_kv", + "fmha_out", + paddle::Optional("rotary_embs"), + paddle::Optional("attn_mask"), + paddle::Optional("qkv_bias"), + paddle::Optional("qkv_out_scales"), + paddle::Optional("cache_k_quant_scales"), + 
paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("out_linear_shifts"), + paddle::Optional("out_linear_smooths"), + paddle::Optional("mask_offset"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}, + {"key_cache", "key_cache_out"}, + {"value_cache", "value_cache_out"}}) + .Attrs({"compute_type: std::string", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "max_partition_size: int", + "encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool", + "rms_norm_eps: float"}) + .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput)) + .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index d43a4af5cb..a7952abf77 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -91,6 +91,49 @@ std::vector AppendAttention( const int speculate_max_draft_token_num, const bool causal, const bool speculate_decoder); +void AppendAttentionWithOutput( + const paddle::Tensor &qkv, const paddle::Tensor &key_cache, + const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids, + const paddle::Tensor &encoder_tile_ids_per_batch, + const paddle::Tensor &encoder_num_blocks, + const paddle::Tensor &kv_batch_ids, + const paddle::Tensor &kv_tile_ids_per_batch, + const paddle::Tensor &kv_num_blocks, + const paddle::Tensor &decoder_batch_ids, + const paddle::Tensor &decoder_tile_ids_per_batch, + const paddle::Tensor &decoder_num_blocks, + const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, + paddle::Tensor &res, + const paddle::optional &rotary_embs, + const paddle::optional &attn_mask, + const paddle::optional &qkv_bias, + const paddle::optional &qkv_out_scales, + const paddle::optional &cache_k_quant_scales, + const paddle::optional &cache_v_quant_scales, + const paddle::optional &cache_k_dequant_scales, + const paddle::optional &cache_v_dequant_scales, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &out_linear_shifts, + const paddle::optional &out_linear_smooths, + const paddle::optional &mask_offset, + const paddle::optional &kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string &compute_dtype, const std::string &cache_quant_type_str, + const bool use_neox_rotary_style, const bool rope_3d, + const int max_input_length, const float quant_max_bound, + const float quant_min_bound, const float out_linear_in_scale, + const int encoder_block_shape_q, const int 
decoder_block_shape_q, + const int max_partition_size, const int encoder_max_partition_size, + const int speculate_max_draft_token_num, const bool causal, + const bool speculate_decoder); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor &qkv, const paddle::Tensor &key_cache, const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q, @@ -829,6 +872,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * append_attention */ m.def("append_attention", &AppendAttention, "append attention function"); + m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function"); /** * gqa_rope_write_cache.cu * gqa_rope_write_cache diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index ea6bdd6ab6..146661790a 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -24,6 +24,7 @@ from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + append_attention_with_output, get_block_shape_and_split_kv_block, init_kv_signal_per_query, init_signal_layerwise, @@ -122,6 +123,7 @@ def __init__( fd_config.parallel_config.expert_parallel_rank = 0 self.rank, self.device_id = init_rank_and_device_id(fd_config) + self.use_output = fd_config.graph_opt_config.full_cuda_graph def init_attention_metadata(self, forward_meta: ForwardMeta): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" @@ -229,58 +231,149 @@ def forward_mixed( layer.layer_id + self.start_layer_index, ) - res = append_attention( - qkv, - forward_meta.caches[2 * layer.layer_id], - forward_meta.caches[2 * layer.layer_id + 1], - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.max_len_tensor_cpu, - metadata.max_len_kv, - metadata.rotary_embs, - metadata.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - getattr(layer, "cache_k_scale", None), - getattr(layer, "cache_v_scale", None), - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - metadata.mask_offset, - metadata.kv_signal_data_list[layer.layer_id], - getattr(layer, "q_norm_weight", None), - getattr(layer, "k_norm_weight", None), - getattr(layer, "rms_norm_eps", 1e-6), - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - self.encoder_block_shape_q, - self.decoder_block_shape_q, - metadata.max_partition_size, - metadata.encoder_max_partition_size, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - )[0] + if self.use_output: + quant_max_bound = getattr(layer, "quant_max_bound", 0.0) + cache_quant_type = getattr(layer, "cache_quant_type_str", 
"none") + compute_type = metadata._fuse_kernel_compute_dtype + out_scale = getattr(layer, "out_scale", -1.0) + # 1. get output datatype + qkv_dtype = qkv.dtype + if qkv_dtype == paddle.float16: + D_type = paddle.float16 + elif qkv_dtype == paddle.bfloat16: + D_type = paddle.bfloat16 + elif qkv_dtype == paddle.int32: + if compute_type == "bf16": + D_type = paddle.bfloat16 + elif compute_type == "fp16": + D_type = paddle.float16 + else: + raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].") + else: + raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].") + # 2.Extract related parameters + token_nums = qkv.shape[0] + head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2 + q_num_heads = self.num_heads + # 3. generate output tensor of different dtypes + if out_scale > 0.0: + if abs(quant_max_bound - 127) < 0.000001: + res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place) + elif abs(quant_max_bound - 448) < 0.000001: + res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place) + else: + raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].") + else: + res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place) + + append_attention_with_output( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, + metadata.max_len_kv, + res, + metadata.rotary_embs, + metadata.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + metadata.mask_offset, + metadata.kv_signal_data_list[layer.layer_id], + getattr(layer, "q_norm_weight", None), + getattr(layer, "k_norm_weight", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + metadata.max_partition_size, + metadata.encoder_max_partition_size, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + else: + res = append_attention( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + 
metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu, + metadata.max_len_kv, + metadata.rotary_embs, + metadata.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + metadata.mask_offset, + metadata.kv_signal_data_list[layer.layer_id], + getattr(layer, "q_norm_weight", None), + getattr(layer, "k_norm_weight", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + metadata.max_partition_size, + metadata.encoder_max_partition_size, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) return res diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index ed92483932..bf63aae11f 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -377,7 +377,7 @@ def forward_mixed( self.speculate_max_draft_token_num + 1, self.causal, self.speculative_method is not None, - )[0] + ) if metadata.max_len_tensor_cpu[1] > 0: merge_prefill_decode_output( diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index f2f629d94d..caf8bcb9b5 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. 
""" -from .append_attention import append_attention +from .append_attention import append_attention, append_attention_with_output from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block from .gqa_rope_write_cache import gqa_rope_write_cache from .init_kv_signal_per_query import init_kv_signal_per_query @@ -25,6 +25,7 @@ __all__ = [ "get_block_shape_and_split_kv_block", "append_attention", + "append_attention_with_output", "open_shm_and_get_meta_signal", "init_signal_layerwise", "gqa_rope_write_cache", diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index bbcf8a1e93..7cf9636876 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -24,6 +24,9 @@ from fastdeploy.model_executor.ops.gpu import ( append_attention as append_attention_gpu, ) + from fastdeploy.model_executor.ops.gpu import ( + append_attention_with_output as append_attention_with_output_gpu, + ) def append_attention( @@ -141,3 +144,124 @@ def append_attention( return out else: raise NotImplementedError + + +# TODO: (mengyuan) merge w/o output version append attention after +# finishing developing sub-graph cudagraph capture to reduce +# compilation volume +def append_attention_with_output( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + encoder_batch_ids: paddle.Tensor, + encoder_tile_ids_per_batch: paddle.Tensor, + encoder_num_blocks: paddle.Tensor, + kv_batch_ids: paddle.Tensor, + kv_tile_ids_per_batch: paddle.Tensor, + kv_num_blocks: paddle.Tensor, + decoder_batch_ids: paddle.Tensor, + decoder_tile_ids_per_batch: paddle.Tensor, + decoder_num_blocks: paddle.Tensor, + set_max_lengths: paddle.Tensor, + max_len_kv: paddle.Tensor, + out: paddle.tensor, # attention output + rotary_embs: Optional[paddle.Tensor] = None, + attn_mask: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + qkv_scale: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + linear_shift: Optional[paddle.Tensor] = None, + linear_smooth: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + compute_type: str = "bf16", + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + out_linear_in_scale: float = -1.0, + encoder_block_shape_q: int = 64, + decoder_block_shape_q: int = 16, + max_partition_size: int = 32768, + encoder_max_partition_size: int = 32768, + speculate_max_draft_token_num: int = 1, + causal: bool = True, + speculate_decoder: bool = False, +) -> None: + """ + append_attention + """ + if current_platform.is_cuda(): + append_attention_with_output_gpu( + qkv, + 
key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + max_len_kv, + out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_scale, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + linear_shift, + linear_smooth, + mask_offset, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + compute_type, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder, + ) + else: + raise NotImplementedError diff --git a/test/layers/test_append_attention.py b/test/layers/test_append_attention.py index e3e4de158c..5d45c58101 100644 --- a/test/layers/test_append_attention.py +++ b/test/layers/test_append_attention.py @@ -532,7 +532,7 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask speculate_max_draft_token_num + 1, # speculate_max_draft_token_num True, # causal False, # speculate_decoder - )[0] + ) paddle.device.synchronize() end_time = time.time() print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms") diff --git a/test/layers/test_append_attention_with_output.py b/test/layers/test_append_attention_with_output.py new file mode 100644 index 0000000000..db93c87da6 --- /dev/null +++ b/test/layers/test_append_attention_with_output.py @@ -0,0 +1,623 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +import unittest + +import numpy as np +import paddle +from paddle.incubate.nn.functional import fused_rms_norm + +paddle.seed(10) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_neox_style_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, 1, D] + emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim)) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, D/2] + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + # shape: [B, S, 1, D] + emb = paddle.unsqueeze(emb, 2) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, v=None, causal=False): + # sin [sequence_length, embed_size_per_head//2] + # cos [sequence_length, embed_size_per_head//2] + # sin, cos = paddle.chunk(rp, 2, axis=-1) + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + + if self.use_neox_rotary_style: + sin_pos = sin + cos_pos = cos + # NeoX Stype:前后半部分分块旋转 + rotate_half_q = paddle.reshape( + paddle.stack( + [ + -q[:, :, :, q.shape[-1] // 2 :], + q[:, :, :, : q.shape[-1] // 2], + ], + axis=-1, + ), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack( + [ + -k[:, :, :, k.shape[-1] // 2 :], + k[:, :, :, : k.shape[-1] // 2], + ], + axis=-1, + ), + paddle.shape(k), + ) + else: + # import pdb;pdb.set_trace() + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + # GPT Stype:奇偶位置分块旋转 + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + + key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos)) + + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def create_attn_mask( + mask_type, + batch_size, + seq_lens, + pre_cache_length=0, +): + max_seq_len = max(seq_lens) + mask = paddle.zeros( + # [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length], + [batch_size, 1, max_seq_len, max_seq_len], + dtype=mask_type, + ) + mask[:, 
:, :, :pre_cache_length] = 1 + for i in range(batch_size): + seq_len = seq_lens[i] + mask[i, 0, :seq_len, :seq_len] = ( + paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) - 1 + ) * 1e4 + return mask + + +def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len): + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + +def naive_attention_impl( + query, + key, + value, + cache_k=None, + cache_v=None, + pre_cache_k=None, + pre_cache_v=None, + mask=None, + scale=1.0, + cache_k_dequant_scales=None, + cache_v_dequant_scales=None, + use_cachekv_int8="None", + q_norm_weight=None, + k_norm_weight=None, +): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if cache_k is not None: + cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim]) + cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1]) + cache_k = cache_k.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([cache_k, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if cache_v is not None: + cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim]) + cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1]) + cache_v = cache_v.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([cache_v, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + if mask is not None: + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def get_padding_offset(bsz, max_seq_len, seq_lens_this_time): + cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time) + cum_offsets = paddle.zeros(shape=(bsz + 1), dtype="int32") + cum_offsets[1:] = cum_offsets_now + token_num = paddle.sum(seq_lens_this_time) + padding_offsets = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + for i in range(bsz): + seq_len_now = seq_lens_this_time[i] + cum_offset = cum_offsets[i] + for j in range(seq_len_now): + padding_offsets[i * max_seq_len - cum_offset + j] = cum_offset + cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i + 1] + cu_seqlens_q[i + 1] = cum_seq_len + cu_seqlens_k[i + 1] = cum_seq_len + return padding_offsets, cum_offsets[:-1], cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, dim_head = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * dim_head], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, 
-1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, dim_head]) / 10 + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) + key = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10 + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) + value = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10 + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +def apply_qk_norm(head_dim, dtype, q, k): + q_norm_weight = np.random.random([head_dim]) / 10 + k_norm_weight = np.random.random([head_dim]) / 10 + q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype) + k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype) + print("q:", q.shape) + print("k:", k.shape) + bs, q_num_head, seq_len, dim_head = q.shape + _, kv_num_head, _, _ = k.shape + + q = q.reshape([-1, head_dim]) + k = k.reshape([-1, head_dim]) + print("q:", q) + q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0] + print("q after norm:", q) + k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0] + q = q.reshape([-1, q_num_head, seq_len, dim_head]) + k = k.reshape([-1, kv_num_head, seq_len, dim_head]) + return q, k, q_norm_weight_tensor, k_norm_weight_tensor + + +def split_query_by_phase( + query, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + q_dim, + k_dim, + v_dim, +): + """ + 将 query 拆分为 encoder 和 decoder 的 Q/K/V。 + """ + + batch = seq_lens_encoder.shape[0] + max_seq = query.shape[0] // batch + + # 还原 query 为 [batch, seq, dim] + total_dim = q_dim + k_dim + v_dim + query = paddle.reshape(query, [batch, max_seq, total_dim]) + + # 计算 mask,表示该 batch 是否是 encoder/decoder + is_encoder = (seq_lens_encoder > 0).astype("bool").reshape([-1]) # [batch] + is_decoder = (seq_lens_decoder > 0).astype("bool").reshape([-1]) # [batch] + + # 准备输出列表 + enc_qs, enc_ks, enc_vs = [], [], [] + dec_qs, dec_ks, dec_vs = [], [], [] + + for i in range(batch): + real_len = int(seq_lens_this_time[i]) # 当前 batch 的有效长度 + cur_query = query[i, :real_len, :] # [seq_i, q+k+v] + + q, k, v = paddle.split(cur_query, [q_dim, k_dim, v_dim], axis=-1) + + if is_encoder[i]: + enc_qs.append(q) + enc_ks.append(k) + enc_vs.append(v) + elif is_decoder[i]: + dec_qs.append(q) + dec_ks.append(k) + dec_vs.append(v) + + if enc_qs: + enc_q = paddle.concat(enc_qs, axis=0) + enc_k = paddle.concat(enc_ks, axis=0) + enc_v = paddle.concat(enc_vs, axis=0) + else: + enc_q = enc_k = enc_v = paddle.zeros([0, q_dim], dtype=query.dtype) + + if dec_qs: + dec_q = paddle.concat(dec_qs, axis=0) + dec_k = paddle.concat(dec_ks, axis=0) + dec_v = paddle.concat(dec_vs, axis=0) + else: + dec_q = dec_k = dec_v = paddle.zeros([0, q_dim], dtype=query.dtype) + + return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v) + + +class TestAppendGroupQueryAttnWithRope(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = 
"TestAppendGroupQueryAttnWithRope" + self.place = paddle.CUDAPlace(0) + self.batch_size = 1 + self.q_num_head = 12 + self.kv_num_head = 2 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 128 + self.q_hid_dim = self.q_num_head * self.dim_head + self.kv_hid_dim = self.kv_num_head * self.dim_head + self.blocksize = 64 + self.use_neox_rotary_style = False + # max_seq_len = self.seq_len + self.max_dec_len + self.max_seq_len = self.seq_len + self.max_dec_len + self.softmax_scale = self.dim_head**-0.5 + self.rope_theta = 10000 + self.dtype = "float16" + self.use_qk_norm = True + self.init_tensor() + + def init_tensor(self): + self.block_num_per_seq = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize + self.rope = RopeEmbedding(self.use_neox_rotary_style) + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + + self.seq_lens_enc = [ + self.seq_len, + ] * self.batch_size + self.seq_lens_dec = [ + 0, + ] * self.batch_size + self.max_enc_len_this_time = max(self.seq_lens_enc) + self.max_dec_len_this_time = max(self.seq_lens_dec) + self.seq_lens_encoder = paddle.to_tensor( + self.seq_lens_enc, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + self.seq_lens_dec, + "int32", + ) + self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace()) + self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace()) + self.seq_lens_this_time = self.seq_lens_encoder + + self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") + self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") + self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() + self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.blocksize, + self.dim_head, + ) + + self.scale = 1.0 / np.sqrt(self.dim_head) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time) + self.token_num = self.padding_offset.shape[0] + + def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None): + paddle.disable_static() + self.token_num = self.seq_len * self.batch_size + q, k, v, qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.seq_len, + self.dim_head, + self.place, + self.dtype, + ) + + q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True) + if self.use_qk_norm: + q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k) + else: + q_norm_weight = None + k_norm_weight = None + out_ = naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + attn_mask, + self.scale, + ) + out_ = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num) + speculate_max_draft_token_num = 1 + from 
fastdeploy.model_executor.layers.attention.ops import ( + append_attention_with_output, + get_block_shape_and_split_kv_block, + ) + + ( + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + max_len_kv, + ) = get_block_shape_and_split_kv_block( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.decoder_batch_ids, + self.decoder_tile_ids_per_batch, + self.decoder_num_blocks_cpu, + self.max_len_tensor_cpu, + 64, + 12, + (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head, + self.blocksize, + speculate_max_draft_token_num + 1, + ) + + # Warm up + WARM_UP = 1 + RUN_TIME = 2 + out = paddle.zeros((qkv.shape[0], self.q_hid_dim), dtype=q.dtype).to(q.place) + for i in range(WARM_UP + RUN_TIME): + if i == WARM_UP: + paddle.device.synchronize() + start_time = time.time() + append_attention_with_output( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + self.decoder_batch_ids, + self.decoder_tile_ids_per_batch, + self.decoder_num_blocks_cpu, + self.max_len_tensor_cpu, + max_len_kv, + out, # attention output + self.rope_emb, # rope_emb + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # kv_signal_data + q_norm_weight, # q_norm_weight + k_norm_weight, # k_norm_weight + 1e-6, + "fp16", + "none", # cache_quant_type + self.use_neox_rotary_style, + False, + self.max_seq_len, + 0.0, # quant_min_bound + 0.0, # quant_max_bound + -1, # out_linear_in_scale + 64, # encoder_block_shape_q + 16, # decoder_block_shape_q + 32768, # max_partition_size + 32768, # encoder_max_partition_size + speculate_max_draft_token_num + 1, # speculate_max_draft_token_num + True, # causal + False, # speculate_decoder + ) + paddle.device.synchronize() + end_time = time.time() + print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms") + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=1e-02, + atol=1e-02, + ) + + def test_all(self): + tmp_position_ids = paddle.arange(self.seq_len + self.max_dec_len).reshape((1, -1)) + # appendattn 传的是最大maxseq + if self.use_neox_rotary_style: + self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head) + else: + self.rope_emb = self.rope.get_rotary_position_embedding(tmp_position_ids, self.dim_head) + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + # encoder + # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len) + self.seq_lens_this_time = self.seq_lens_encoder + self.cmp_append_attention(attn_mask=self.attention_mask) + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + # decoder + 
self.seq_lens_decoder[:] = self.seq_lens_encoder
+        self.seq_lens_encoder[:] = 0
+        self.seq_lens_this_time[:] = 1
+        self.seq_lens_enc = [
+            0,
+        ] * self.batch_size
+        self.seq_lens_dec = [
+            self.seq_len,
+        ] * self.batch_size
+        self.max_enc_len_this_time = max(self.seq_lens_enc)
+        self.max_dec_len_this_time = max(self.seq_lens_dec)
+        self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
+        self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
+
+        self.seq_len = 1
+        (
+            self.padding_offset,
+            self.cum_offset,
+            self.cu_seqlens_q,
+            self.cu_seqlens_k,
+        ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
+        self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
+
+
+class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
+    def setUp(self):
+        paddle.disable_static()
+        self.name = "TestAppendGroupQueryAttnWithNeoXRope"
+        self.place = paddle.CUDAPlace(0)
+        self.batch_size = 1
+        self.q_num_head = 12
+        self.kv_num_head = 2
+        self.seq_len = 64
+        self.max_dec_len = 64
+        self.dim_head = 128
+        self.q_hid_dim = self.q_num_head * self.dim_head
+        self.kv_hid_dim = self.kv_num_head * self.dim_head
+        self.blocksize = 64
+        self.use_neox_rotary_style = True
+        # max_seq_len = self.seq_len + self.max_dec_len
+        self.max_seq_len = self.seq_len + self.max_dec_len
+        self.softmax_scale = self.dim_head**-0.5
+        self.rope_theta = 10000
+        self.dtype = "float16"
+        self.use_qk_norm = False
+        self.init_tensor()
+
+
+if __name__ == "__main__":
+    unittest.main()
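
Usage note: with this change the caller, not the kernel, owns the fmha_out buffer, so the allocation rule that used to live inside AppendAttentionKernel (and is now mirrored in AppendAttention, in the Python backend, and in the unit test) must be reproduced wherever the new op is called. The sketch below is a minimal illustration of that rule, assuming only Paddle; the helper name alloc_fmha_out and the literal sizes in the commented example are illustrative, not part of the patch.

import paddle

def alloc_fmha_out(token_num, q_num_heads, head_dim, compute_dtype, out_linear_in_scale, quant_max_bound):
    # When out_linear_in_scale > 0 the kernel writes a quantized output:
    # int8 for quant_max_bound == 127, float8_e4m3fn for quant_max_bound == 448.
    # Otherwise the output keeps the compute dtype ("float16" or "bfloat16").
    if out_linear_in_scale > 0.0:
        if abs(quant_max_bound - 127.0) < 1e-6:
            dtype = "int8"
        elif abs(quant_max_bound - 448.0) < 1e-6:
            dtype = "float8_e4m3fn"
        else:
            raise NotImplementedError("quant_max_bound must be 127 or 448")
    else:
        dtype = compute_dtype
    # Shape matches the kernel's fmha_out: [token_num, q_num_heads * head_dim].
    # Note: for cache_quant_type == "cache_int4_zp" the effective head_dim is doubled
    # (see the meta_data.head_dims handling above).
    return paddle.empty([token_num, q_num_heads * head_dim], dtype=dtype)

# Example (values illustrative): preallocate once, then let the op fill it in place.
# out = alloc_fmha_out(qkv.shape[0], q_num_heads=12, head_dim=128, compute_dtype="float16",
#                      out_linear_in_scale=-1.0, quant_max_bound=0.0)
# append_attention_with_output(qkv, key_cache, value_cache, ..., out, ...)  # returns None

Because the op registers fmha_out as inplace ({"fmha_out", "fmha_out_out"}), the same preallocated buffer can be reused across replays when the call is captured in a full CUDA graph, which is why the backend selects this path when full_cuda_graph is enabled.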