
Commit ed6bff2

fix custom op order rms_norm_eps (#3348)
1 parent 8224b21 commit ed6bff2

File tree

1 file changed (+5, -4 lines)

custom_ops/gpu_ops/append_attention.cu

Lines changed: 5 additions & 4 deletions
@@ -739,12 +739,13 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("out_linear_shifts"),
              paddle::Optional("out_linear_smooths"),
              paddle::Optional("kv_signal_data"),
-             paddle::Optional("q_norm_weight"),
-             paddle::Optional("k_norm_weight")})
+             paddle::Optional("q_norm_weight"),
+             paddle::Optional("k_norm_weight")})
     .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
-    .Attrs({"compute_type: std::string",
+    .Attrs({"rms_norm_eps: float",
+            "compute_type: std::string",
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "rope_3d: bool",
@@ -759,7 +760,7 @@ PD_BUILD_STATIC_OP(append_attention)
             "speculate_max_draft_token_num: int",
             "causal: bool",
             "speculate_decoder: bool",
-            "rms_norm_eps: float"})
+            })
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
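
For context on why the attribute order matters: Paddle's custom-op API passes the attributes declared in .Attrs() to the kernel function after the Tensor inputs, in declaration order, so the list has to line up with the C++ signature of AppendAttention (which evidently takes rms_norm_eps before compute_type). Below is a minimal sketch of that rule with a hypothetical my_norm op and MyNormKernel function; it is an illustration under that assumption, not the real append_attention registration.

// Minimal sketch (hypothetical op): .Attrs() order must match the kernel
// parameter order that follows the Tensor inputs.
#include <string>
#include <vector>
#include "paddle/extension.h"

// Attributes arrive after the Tensor inputs, in the order they appear in
// .Attrs(), so this signature expects rms_norm_eps first, then compute_type.
std::vector<paddle::Tensor> MyNormKernel(const paddle::Tensor& x,
                                         float rms_norm_eps,
                                         const std::string& compute_type) {
  // Kernel body elided; return the input unchanged for the sake of the sketch.
  return {x};
}

PD_BUILD_OP(my_norm)
    .Inputs({"x"})
    .Outputs({"out"})
    .Attrs({"rms_norm_eps: float",         // 1st attribute -> 2nd kernel parameter
            "compute_type: std::string"})  // 2nd attribute -> 3rd kernel parameter
    .SetKernelFn(PD_KERNEL(MyNormKernel));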
