Commit 440a44d

fix sot bug

1 parent 22038e6

2 files changed: +5 -4 lines

custom_ops/gpu_ops/append_attention.cu

Lines changed: 4 additions & 3 deletions
@@ -1057,7 +1057,7 @@ PD_BUILD_STATIC_OP(append_attention)
             paddle::Optional("kv_signal_data"),
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight")})
-    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
     .Attrs({"rms_norm_eps: float",
@@ -1123,7 +1123,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
     .SetInplaceMap({{"fmha_out", "fmha_out_out"},
                     {"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
-    .Attrs({"compute_type: std::string",
+    .Attrs({"rms_norm_eps: float",
+            "compute_type: std::string",
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "rope_3d: bool",
@@ -1138,7 +1139,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
             "speculate_max_draft_token_num: int",
             "causal: bool",
             "speculate_decoder: bool",
-            "rms_norm_eps: float"})
+            })
     .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));
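
Why the .Attrs() reorder is likely a functional fix rather than a cosmetic one: as far as I understand Paddle's custom-op framework, the strings passed to .Attrs() are bound positionally to the kernel function's trailing attribute parameters, so "rms_norm_eps: float" must sit wherever the kernel expects the float. A minimal sketch of that contract under this positional-binding assumption (DemoKernel and demo_op are hypothetical names, not part of this commit):

#include "paddle/extension.h"

// Hypothetical kernel: attribute arguments trail the tensor inputs, and
// their order must match the strings registered in .Attrs() below.
// (Assumption: attributes are matched by position, not by name.)
std::vector<paddle::Tensor> DemoKernel(const paddle::Tensor& x,
                                       float rms_norm_eps,
                                       const std::string& compute_type) {
  // A real kernel would use rms_norm_eps / compute_type here.
  return {x};
}

PD_BUILD_STATIC_OP(demo_op)
    .Inputs({"x"})
    .Outputs({"out"})
    // "rms_norm_eps: float" is listed first because DemoKernel takes the
    // float attribute before the string one; listing it last, as the old
    // registration did, would feed the attributes to the wrong parameters.
    .Attrs({"rms_norm_eps: float",
            "compute_type: std::string"})
    .SetKernelFn(PD_KERNEL(DemoKernel));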

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ void AppendAttentionWithOutput(
     const paddle::Tensor &decoder_tile_ids_per_batch,
     const paddle::Tensor &decoder_num_blocks,
     const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
-    paddle::Tensor &res,
+    paddle::Tensor &fmha_out,
     const paddle::optional<paddle::Tensor> &rotary_embs,
     const paddle::optional<paddle::Tensor> &attn_mask,
     const paddle::optional<paddle::Tensor> &qkv_bias,
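
This hunk is a parameter rename in the forward declaration only: the mutable output tensor registered in append_attention.cu as the in-place pair {"fmha_out", "fmha_out_out"} was declared as res here. C++ permits declaration and definition parameter names to differ, so this reads as a consistency fix that lets the declaration be matched against the registration by name. A trimmed sketch of the aligned declaration (most parameters elided; only the renamed one and its neighbors are shown, per the diff above):

#include "paddle/extension.h"

// Declaration now uses the same name as the op registration, so readers
// can match it to .SetInplaceMap({{"fmha_out", "fmha_out_out"}, ...}).
// (Trimmed signature for illustration; the full parameter list is unchanged.)
void AppendAttentionWithOutput(const paddle::Tensor &set_max_lengths,
                               const paddle::Tensor &max_len_kv,
                               paddle::Tensor &fmha_out,  // was: &res
                               const paddle::optional<paddle::Tensor> &rotary_embs);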
