14 changes: 6 additions & 8 deletions custom_ops/gpu_ops/append_attention.cu
@@ -596,7 +596,7 @@ std::vector<paddle::Tensor> AppendAttention(
   return {paddle::Tensor{}};
 }
 
-void AppendAttentionWithOutput(
+std::vector<paddle::Tensor> AppendAttentionWithOutput(
     const paddle::Tensor& qkv,
     const paddle::Tensor& key_cache,
     const paddle::Tensor& value_cache,
@@ -761,6 +761,8 @@ void AppendAttentionWithOutput(
       break;
     }
   }
+
+  return {fmha_out};
 }
 
@@ -1063,9 +1065,7 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("kv_signal_data"),
              paddle::Optional("q_norm_weight"),
              paddle::Optional("k_norm_weight")})
-    .Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
-    .SetInplaceMap({{"key_cache", "key_cache_out"},
-                    {"value_cache", "value_cache_out"}})
+    .Outputs({"fmha_out"})
     .Attrs({"rms_norm_eps: float",
             "compute_type: std::string",
             "cache_quant_type: std::string",
@@ -1125,10 +1125,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
              paddle::Optional("kv_signal_data"),
              paddle::Optional("q_norm_weight"),
              paddle::Optional("k_norm_weight")})
-    .Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
-    .SetInplaceMap({{"fmha_out", "fmha_out_out"},
-                    {"key_cache", "key_cache_out"},
-                    {"value_cache", "value_cache_out"}})
+    .Outputs({"fmha_out_out"})
+    .SetInplaceMap({{"fmha_out", "fmha_out_out"}})
     .Attrs({"rms_norm_eps: float",
             "compute_type: std::string",
             "cache_quant_type: std::string",
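With this change the kernel reports its result explicitly: AppendAttentionWithOutput now returns {fmha_out}, the append_attention registration drops the cache outputs entirely, and append_attention_with_output keeps a single output bound to fmha_out through SetInplaceMap. A minimal sketch of what that contract looks like from the Python side is below; the wrapper name, argument order, and trailing arguments are simplified assumptions, not the op's real signature.

    import paddle

    def call_attention(op, qkv, key_cache, value_cache, fmha_out, *rest):
        # key_cache / value_cache are still written in place by the kernel, but
        # they are no longer declared as op outputs, so nothing has to be
        # rebound for them at the call site.
        # fmha_out is both filled in place (fmha_out -> fmha_out_out) and
        # returned, so callers can simply capture the return value.
        out = op(qkv, key_cache, value_cache, fmha_out, *rest)
        return out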
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const int speculate_max_draft_token_num, const bool causal,
     const bool speculate_decoder);
 
-void AppendAttentionWithOutput(
+std::vector<paddle::Tensor> AppendAttentionWithOutput(
     const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
     const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
@@ -259,15 +259,15 @@ def forward_mixed(
         # 3. generate output tensor of different dtypes
         if out_scale > 0.0:
             if abs(quant_max_bound - 127) < 0.000001:
-                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8").to(qkv.place)
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="int8")
             elif abs(quant_max_bound - 448) < 0.000001:
-                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn").to(qkv.place)
+                res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="float8_e4m3fn")
             else:
                 raise NotImplementedError("Only supported attr of quant_max_bound in ['127', '448'].")
         else:
-            res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place)
+            res = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type)
 
-        append_attention_with_output(
+        res = append_attention_with_output(
             qkv,
             forward_meta.caches[2 * layer.layer_id],
             forward_meta.caches[2 * layer.layer_id + 1],
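Two things change at this call site: the output buffer is no longer moved with .to(qkv.place) after allocation, and res is rebound to the value returned by append_attention_with_output. A short sketch of the resulting allocation pattern, assuming the default device already matches qkv.place; the shapes and dtype below are placeholders.

    import paddle

    paddle.set_device("gpu")                            # assumed: same device as qkv
    token_nums, q_num_heads, head_dims = 16, 8, 128     # placeholder shapes only

    # paddle.empty allocates on the current device, so the extra .to(qkv.place)
    # hop is redundant once the default device is set correctly.
    res = paddle.empty([token_nums, q_num_heads * head_dims], dtype="bfloat16")

    # The op fills `res` in place and also returns it; rebinding the name keeps
    # the caller pointing at whatever tensor the op reports as its output.
    # res = append_attention_with_output(qkv, key_cache, value_cache, ..., res, ...)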
@@ -208,7 +208,7 @@ def append_attention_with_output(
         append_attention
     """
     if current_platform.is_cuda():
-        append_attention_with_output_gpu(
+        return append_attention_with_output_gpu(
             qkv,
             key_cache,
             value_cache,
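The platform dispatcher now propagates the op's return value instead of calling it purely for its in-place side effect. A hedged sketch of that pattern follows; current_platform and append_attention_with_output_gpu are the names used in the diff, stubbed here so the snippet stands alone, and the non-CUDA fallback is an assumption rather than code from this PR.

    class _Platform:
        def is_cuda(self) -> bool:
            return True

    current_platform = _Platform()

    def append_attention_with_output_gpu(*args, **kwargs):
        ...  # placeholder for the compiled GPU op

    def append_attention_with_output(*args, **kwargs):
        # The CUDA branch returns the GPU op's result so callers can rebind
        # their output tensor to it.
        if current_platform.is_cuda():
            return append_attention_with_output_gpu(*args, **kwargs)
        # Assumed fallback for other platforms (not shown in this diff).
        raise NotImplementedError("append_attention_with_output requires CUDA")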