 namespace custom_kernel {

 struct FusedQkvRopeParams {
+  ns_LayerNormKernel::Params rmsnorm_params;
   int head_dim;
   int num_head;
   int kv_num_head;
@@ -31,6 +32,7 @@ struct FusedQkvRopeParams {
   bool with_qkv_biases = false;
   bool fp8_proj = false;
   bool fp8_out = false;
+  bool use_qk_rmsnorm = false;
 };

 class FusedQkvRope : public HpuFusedOperator {
@@ -48,6 +50,10 @@ class FusedQkvRope : public HpuFusedOperator {
     int qkv_biases_index = 3;
     int scale_input_index = (params.with_qkv_biases ? (qkv_biases_index + 1)
                                                     : (rotary_embs_index + 1));
+    int q_gamma_index =
+        params.fp8_proj
+            ? params.fp8_out ? (scale_input_index + 5) : (scale_input_index + 2)
+            : scale_input_index;

     auto src = createTensorFromCT(&ct, src_index);
     auto qkv_weights = createTensorFromCT(&ct, qkv_weights_index);
@@ -180,7 +186,56 @@ class FusedQkvRope : public HpuFusedOperator {

     std::vector<synTensor> inputs_q;
     std::vector<synTensor> outputs_q;
-    inputs_q.push_back(q_split);
+    std::vector<synTensor> inputs_k;
+    std::vector<synTensor> outputs_k;
+
+    if (params.use_qk_rmsnorm) {
+      auto q_rmsnorm = createTensorNoPresist("q_rmsnorm", dtype_, outs[0].dims);
+      auto k_rmsnorm = createTensorNoPresist("k_rmsnorm", dtype_, kv_dims);
+      synTensor q_gamma = createTensorFromCT(&ct, q_gamma_index);
+      synTensor k_gamma = createTensorFromCT(&ct, q_gamma_index + 1);
+
+      auto tmp_q_dims = outs[0].dims;
+      tmp_q_dims[3] = 1;
+      auto q_rmsnorm_var =
+          createTensorNoPresist("q_rmsnorm_var", dtype_, tmp_q_dims);
+
+      auto tmp_k_dims = kv_dims;
+      tmp_k_dims[3] = 1;
+      auto k_rmsnorm_var =
+          createTensorNoPresist("k_rmsnorm_var", dtype_, tmp_k_dims);
+
+      std::vector<synTensor> rmsnorm_inputs_q;
+      rmsnorm_inputs_q.push_back(q_split);
+      rmsnorm_inputs_q.push_back(q_gamma);
+      std::vector<synTensor> rmsnorm_outputs_q;
+      rmsnorm_outputs_q.push_back(q_rmsnorm);
+      rmsnorm_outputs_q.push_back(q_rmsnorm_var);
+
+      AddNodeRmsNorm<T>(rmsnorm_inputs_q,
+                        rmsnorm_outputs_q,
+                        params.rmsnorm_params,
+                        guid_ + "q_rmsnorm");
+
+      std::vector<synTensor> rmsnorm_inputs_k;
+      rmsnorm_inputs_k.push_back(k_split);
+      rmsnorm_inputs_k.push_back(k_gamma);
+      std::vector<synTensor> rmsnorm_outputs_k;
+      rmsnorm_outputs_k.push_back(k_rmsnorm);
+      rmsnorm_outputs_k.push_back(k_rmsnorm_var);
+
+      AddNodeRmsNorm<T>(rmsnorm_inputs_k,
+                        rmsnorm_outputs_k,
+                        params.rmsnorm_params,
+                        guid_ + "k_rmsnorm");
+
+      inputs_q.push_back(q_rmsnorm);
+      inputs_k.push_back(k_rmsnorm);
+    } else {
+      inputs_q.push_back(q_split);
+      inputs_k.push_back(k_split);
+    }
+
     inputs_q.push_back(sin_sq);
     inputs_q.push_back(cos_sq);

@@ -199,9 +254,6 @@ class FusedQkvRope : public HpuFusedOperator {
                               : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
     AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + "rope_q");

-    std::vector<synTensor> inputs_k;
-    std::vector<synTensor> outputs_k;
-    inputs_k.push_back(k_split);
     inputs_k.push_back(sin_sq);
     inputs_k.push_back(cos_sq);

@@ -277,11 +329,14 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
                         const paddle::optional<phi::DenseTensor>& scale_v,
                         phi::DenseTensor* query_states,
                         phi::DenseTensor* key_value_states,
+                        const paddle::optional<phi::DenseTensor>& q_norm_weight,
+                        const paddle::optional<phi::DenseTensor>& k_norm_weight,
                         const phi::Scalar& head_dim,
                         const phi::Scalar& num_head,
                         const phi::Scalar& total_batch,
                         const phi::Scalar& transpose,
-                        const phi::Scalar& use_neox_style) {
+                        const phi::Scalar& use_neox_style,
+                        const phi::Scalar& epsilon) {
   int total_batch_ = total_batch.to<int>();
   std::vector<int64_t> src_dims = phi::vectorize<int64_t>(src.dims());
   int bsz_seqlen = src_dims[0];
@@ -335,6 +390,16 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
     throw std::runtime_error(
         "Need both scale_input and scale_weight for FusedFp8QkvRopeKernel");
   }
+
+  if (q_norm_weight && k_norm_weight) {
+    guid_prefix += "_qk_norm";
+    ct.Add(q_norm_weight.get());
+    ct.Add(k_norm_weight.get());
+  } else if (q_norm_weight || k_norm_weight) {
+    throw std::runtime_error(
+        "Need both q_norm_weight and k_norm_weight for FusedQkvRopeKernel");
+  }
+
   guid_prefix += "_fwd_";

   OpCacheOperator op_info;
@@ -345,6 +410,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
   if (recipe == nullptr) {
     FusedQkvRopeParams params;
     memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedQkvRopeParams));
+    params.rmsnorm_params.epsValid = true;
+    params.rmsnorm_params.eps = epsilon.to<float>();
     params.head_dim = head_dim_;
     params.num_head = num_head_;
     params.kv_num_head = kv_num_head;
@@ -360,6 +427,9 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
     if (scale_q) {
       params.fp8_out = true;
     }
+    if (q_norm_weight && k_norm_weight) {
+      params.use_qk_rmsnorm = true;
+    }

     FusedQkvRope op(guid_prefix, op_info.datatype_);
     op.AddNode<T>(ct, params);
@@ -390,11 +460,14 @@ void CallFusedQkvRopeKernel(
     const paddle::optional<phi::DenseTensor>& scale_v,
     phi::DenseTensor* query_states,
     phi::DenseTensor* key_value_states,
+    const paddle::optional<phi::DenseTensor>& q_norm,
+    const paddle::optional<phi::DenseTensor>& k_norm,
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
     const phi::Scalar& total_batch,
     const phi::Scalar& transpose,
-    const phi::Scalar& use_neox_style) {
+    const phi::Scalar& use_neox_style,
+    const phi::Scalar& epsilon) {
   if (src.dtype() == phi::DataType::FLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::float16>(dev_ctx,
                                                            src,
@@ -408,11 +481,14 @@ void CallFusedQkvRopeKernel(
                                                            scale_v,
                                                            query_states,
                                                            key_value_states,
+                                                           q_norm,
+                                                           k_norm,
                                                            head_dim,
                                                            num_head,
                                                            total_batch,
                                                            transpose,
-                                                           use_neox_style);
+                                                           use_neox_style,
+                                                           epsilon);
   } else if (src.dtype() == phi::DataType::BFLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::bfloat16>(dev_ctx,
                                                             src,
@@ -426,11 +502,14 @@ void CallFusedQkvRopeKernel(
                                                             scale_v,
                                                             query_states,
                                                             key_value_states,
+                                                            q_norm,
+                                                            k_norm,
                                                             head_dim,
                                                             num_head,
                                                             total_batch,
                                                             transpose,
-                                                            use_neox_style);
+                                                            use_neox_style,
+                                                            epsilon);
   } else {
     throw std::runtime_error("Unsupported data type for FusedQkvRopeKernel");
   }
@@ -441,11 +520,14 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
     const paddle::Tensor& qkv_weights,
     const paddle::optional<paddle::Tensor>& qkv_biases,
     const paddle::Tensor& rotary_embs,
+    const paddle::optional<paddle::Tensor>& q_norm_weights,
+    const paddle::optional<paddle::Tensor>& k_norm_weights,
     int head_dim,
     int num_head,
     int total_batch,
     bool transpose,
-    bool use_neox_style) {
+    bool use_neox_style,
+    float epsilon) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -461,6 +543,22 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
     qkv_biases_tensor = paddle::optional<phi::DenseTensor>(*qkv_biases_dt);
   }

+  auto q_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (q_norm_weights) {
+    auto q_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(q_norm_weights->impl().get());
+    q_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*q_norm_weights_dt);
+  }
+
+  auto k_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (k_norm_weights) {
+    auto k_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(k_norm_weights->impl().get());
+    k_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*k_norm_weights_dt);
+  }
+
   // allocate memory on device.
   int64_t bsz = src.dims()[0];
   int64_t seq_len = bsz / total_batch;
@@ -492,11 +590,14 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
                          paddle::optional<phi::DenseTensor>(),
                          query_states.get(),
                          key_value_states.get(),
+                         q_norm_weights_tensor,
+                         k_norm_weights_tensor,
                          phi::Scalar(head_dim),
                          phi::Scalar(num_head),
                          phi::Scalar(total_batch),
                          phi::Scalar(transpose),
-                         phi::Scalar(use_neox_style));
+                         phi::Scalar(use_neox_style),
+                         phi::Scalar(epsilon));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }

@@ -505,6 +606,8 @@ std::vector<std::vector<int64_t>> FusedQkvRopeShape(
     const std::vector<int64_t>& qkv_weights_shape,
     const paddle::optional<std::vector<int64_t>>& qkv_biases_shape,
     const std::vector<int64_t>& rotary_embs_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weights_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weights_shape,
     int head_dim,
     int num_head,
     int total_batch,
@@ -523,19 +626,26 @@ std::vector<paddle::DataType> FusedQkvRopeDtype(
     const paddle::DataType& src_dtype,
     const paddle::DataType& qkv_weights_dtype,
     const paddle::optional<paddle::DataType>& qkv_biases_dtype,
-    const paddle::DataType& rotary_embs_dtype) {
+    const paddle::DataType& rotary_embs_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weights_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weights_dtype) {
   return {src_dtype, src_dtype};
 }

 PD_BUILD_OP(fused_qkv_rope_bf16)
-    .Inputs(
-        {"src", "qkv_weights", paddle::Optional("qkv_biases"), "rotary_embs"})
+    .Inputs({"src",
+             "qkv_weights",
+             paddle::Optional("qkv_biases"),
+             "rotary_embs",
+             paddle::Optional("q_norm_weights"),
+             paddle::Optional("k_norm_weights")})
     .Outputs({"query_states", "key_value_states"})
     .Attrs({"head_dim: int",
             "num_head: int",
             "total_batch: int",
             "transpose: bool",
-            "use_neox_style: bool"})
+            "use_neox_style: bool",
+            "epsilon: float"})
     .SetKernelFn(PD_KERNEL(FusedQkvRopeImpl))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedQkvRopeShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedQkvRopeDtype));
@@ -550,11 +660,14 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
     const paddle::optional<paddle::Tensor>& scale_q,
     const paddle::optional<paddle::Tensor>& scale_k,
     const paddle::optional<paddle::Tensor>& scale_v,
+    const paddle::optional<paddle::Tensor>& q_norm_weights,
+    const paddle::optional<paddle::Tensor>& k_norm_weights,
     int head_dim,
     int num_head,
     int total_batch,
     bool transpose,
-    bool use_neox_style) {
+    bool use_neox_style,
+    float epsilon) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -599,6 +712,22 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
     scale_v_tensor = paddle::optional<phi::DenseTensor>(*scale_v_dt);
   }

+  auto q_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (q_norm_weights) {
+    auto q_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(q_norm_weights->impl().get());
+    q_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*q_norm_weights_dt);
+  }
+
+  auto k_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (k_norm_weights) {
+    auto k_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(k_norm_weights->impl().get());
+    k_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*k_norm_weights_dt);
+  }
+
   // allocate memory on device.
   int64_t bsz = src.dims()[0];
   int64_t seq_len = bsz / total_batch;
@@ -636,11 +765,14 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
                          scale_v_tensor,
                          query_states.get(),
                          key_value_states.get(),
+                         q_norm_weights_tensor,
+                         k_norm_weights_tensor,
                          phi::Scalar(head_dim),
                          phi::Scalar(num_head),
                          phi::Scalar(total_batch),
                          phi::Scalar(transpose),
-                         phi::Scalar(use_neox_style));
+                         phi::Scalar(use_neox_style),
+                         phi::Scalar(epsilon));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }

@@ -651,6 +783,8 @@ std::vector<std::vector<int64_t>> FusedFp8QkvRopeShape(
     const std::vector<int64_t>& rotary_embs_shape,
     const std::vector<int64_t>& scale_input_shape,
     const std::vector<int64_t>& scale_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weights_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weights_shape,
     int head_dim,
     int num_head,
     int total_batch,
@@ -671,7 +805,9 @@ std::vector<paddle::DataType> FusedFp8QkvRopeDtype(
     const paddle::optional<paddle::DataType>& qkv_biases_dtype,
     const paddle::DataType& rotary_embs_dtype,
     const paddle::DataType& scale_input_dtype,
-    const paddle::DataType& scale_weight_dtype) {
+    const paddle::DataType& scale_weight_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weights_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weights_dtype) {
   return {src_dtype, src_dtype};
 }

@@ -684,13 +820,16 @@ PD_BUILD_OP(fused_qkv_rope)
              "scale_weight",
              "scale_q",
              "scale_k",
-             "scale_v"})
+             "scale_v",
+             paddle::Optional("q_rmsnorm_weights"),
+             paddle::Optional("k_rmsnorm_weights")})
     .Outputs({"query_states", "key_value_states"})
     .Attrs({"head_dim: int",
             "num_head: int",
             "total_batch: int",
             "transpose: bool",
-            "use_neox_style: bool"})
+            "use_neox_style: bool",
+            "epsilon: float"})
     .SetKernelFn(PD_KERNEL(FusedFp8QkvRopeImpl))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8QkvRopeShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8QkvRopeDtype));
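
The trickiest part of the change is where the new gamma tensors land in the fused op's input list: `q_gamma_index` has to skip whichever optional scale tensors were registered via `ct.Add(...)` before `q_norm_weight`/`k_norm_weight`. The standalone sketch below reproduces that index arithmetic so it can be checked in isolation. It is only an illustration: the absolute values of `rotary_embs_index` (and the overall `ct.Add` order) are assumptions, since only `qkv_biases_index = 3` is visible in this diff.

```cpp
// Standalone sketch of the input-index layout mirrored by q_gamma_index.
// Assumed ConvertTensors order: src, qkv_weights, rotary_embs, [qkv_biases],
// [scale_input, scale_weight, [scale_q, scale_k, scale_v]],
// q_norm_weight, k_norm_weight.
#include <cassert>

static int QGammaIndex(bool with_qkv_biases, bool fp8_proj, bool fp8_out) {
  const int rotary_embs_index = 2;  // assumed; not shown in this diff
  const int qkv_biases_index = 3;   // matches the kernel above
  // First slot after the mandatory inputs (and qkv_biases, when present).
  const int scale_input_index =
      with_qkv_biases ? (qkv_biases_index + 1) : (rotary_embs_index + 1);
  // fp8 projection adds scale_input/scale_weight (+2); fp8 output additionally
  // adds scale_q/scale_k/scale_v (+5 in total), exactly as q_gamma_index does.
  return fp8_proj ? (fp8_out ? scale_input_index + 5 : scale_input_index + 2)
                  : scale_input_index;
}

int main() {
  // bf16 path without biases: the q gamma directly follows rotary_embs,
  // and the k gamma sits at q_gamma_index + 1.
  assert(QGammaIndex(false, false, false) == 3);
  // fp8 projection with fp8 outputs and biases: gammas follow the five scales.
  assert(QGammaIndex(true, true, true) == 9);
  return 0;
}
```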