
Add with_output version AppendAttention #3302


Open · wants to merge 2 commits into base: develop
435 changes: 397 additions & 38 deletions custom_ops/gpu_ops/append_attention.cu

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -90,6 +90,48 @@ std::vector<paddle::Tensor> AppendAttention(
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);

void AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids_per_batch,
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &res,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &qkv_bias,
const paddle::optional<paddle::Tensor> &qkv_out_scales,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_k_zp,
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
const float quant_min_bound, const float out_linear_in_scale,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);

std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q,
@@ -828,6 +870,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attention
*/
m.def("append_attention", &AppendAttention, "append attention function");
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
/**
* gqa_rope_write_cache.cu
* gqa_rope_write_cache
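The new binding differs from append_attention only in that the attention output res is passed in as a preallocated tensor (note the paddle::Tensor &res parameter after max_len_kv in the declaration above) instead of being allocated inside the op and returned. A toy sketch of why that out-parameter style matters for the full_cuda_graph path wired up later in this PR: a caller-owned buffer keeps the output address fixed across CUDA graph replays, while a returning op hands back a fresh allocation on every call. This is illustration only, not FastDeploy code; attn_alloc, attn_into, and the relu stand-in are assumed names.

import paddle

def attn_alloc(x):
    # Returning style: allocates a new output tensor on every call.
    return paddle.nn.functional.relu(x)  # stand-in for the fused attention op

def attn_into(x, out):
    # Out-parameter style: writes into a caller-owned buffer whose
    # device address never changes, which is what graph capture needs.
    paddle.assign(paddle.nn.functional.relu(x), out)

x = paddle.randn([4, 8], dtype="float32")
y = attn_alloc(x)            # new tensor (new address) each call
buf = paddle.empty_like(x)   # allocated once, reused every step
for _ in range(3):
    attn_into(x, buf)        # same buffer (same address) each iteration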
198 changes: 145 additions & 53 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -24,6 +24,7 @@

from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
append_attention_with_output,
get_block_shape_and_split_kv_block,
init_kv_signal_per_query,
init_signal_layerwise,
@@ -121,6 +122,7 @@ def __init__(
fd_config.parallel_config.expert_parallel_rank = 0

self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.use_output = fd_config.graph_opt_config.full_cuda_graph  # prefer the preallocated-output kernel under full CUDA graph capture

def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
@@ -228,57 +230,147 @@ def forward_mixed(
layer.layer_id + self.start_layer_index,
)

res = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
if self.use_output:
Collaborator

Modifying only this file isn't enough; do a global search for every place that calls append_attention.

Author

OK.

quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
cache_quant_type = getattr(layer, "cache_quant_type_str", "none")
compute_type = metadata._fuse_kernel_compute_dtype
out_scale = getattr(layer, "out_scale", -1.0)
# 1. Determine the output dtype
qkv_dtype = qkv.dtype
if qkv_dtype == paddle.float16:
D_type = paddle.float16
elif qkv_dtype == paddle.bfloat16:
D_type = paddle.bfloat16
elif qkv_dtype == paddle.int32:
if compute_type == "bf16":
D_type = paddle.bfloat16
elif compute_type == "fp16":
D_type = paddle.float16
else:
raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16'].")
else:
raise NotImplementedError("Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].")
# 2. Extract the related shape parameters
token_nums = qkv.shape[0]
head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2
q_num_heads = self.num_heads
# 3. Allocate the output tensor with the selected dtype
if out_scale > 0.0:
if abs(quant_max_bound - 127) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
elif abs(quant_max_bound - 448) < 0.000001:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
else:
raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
else:
res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)

append_attention_with_output(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
res,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)
else:
res = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)
return res
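For reference, the dtype branching in forward_mixed above collapses to a small decision table. The helper below is a hypothetical restatement for illustration only, not part of this PR, using string dtype names in place of paddle dtypes:

def output_dtype(qkv_dtype, compute_type, out_scale, quant_max_bound):
    # Quantized output path: out_scale > 0 selects int8 or fp8 by the bound.
    if out_scale > 0.0:
        if abs(quant_max_bound - 127) < 1e-6:
            return "int8"
        if abs(quant_max_bound - 448) < 1e-6:
            return "float8_e4m3fn"
        raise NotImplementedError("quant_max_bound must be 127 or 448.")
    # Unquantized path: int32 qkv defers to the compute dtype.
    if qkv_dtype == "int32":
        return {"bf16": "bfloat16", "fp16": "float16"}[compute_type]
    return qkv_dtype  # float16 / bfloat16 pass through unchanged

assert output_dtype("bfloat16", "bf16", -1.0, 0.0) == "bfloat16"
assert output_dtype("int32", "fp16", 1.0, 448.0) == "float8_e4m3fn"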
@@ -377,7 +377,7 @@ def forward_mixed(
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
)

if metadata.max_len_tensor_cpu[1] > 0:
merge_prefill_decode_output(
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/attention/ops/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""

from .append_attention import append_attention
from .append_attention import append_attention, append_attention_with_output
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -25,6 +25,7 @@
__all__ = [
"get_block_shape_and_split_kv_block",
"append_attention",
"append_attention_with_output",
"open_shm_and_get_meta_signal",
"init_signal_layerwise",
"gqa_rope_write_cache",
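With the export change above, both entry points become importable side by side. A minimal smoke test of the new surface, assuming a FastDeploy build with the custom ops compiled in:

from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,
    append_attention_with_output,
)

# Both ops are exposed by the package after this PR.
assert callable(append_attention) and callable(append_attention_with_output)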