
Commit 33549c2

[INTEL_HPU] Support q/k rms norm in fused_qkv_rope op (PaddlePaddle#2187)
Signed-off-by: Fei Wang <[email protected]>
1 parent: c6c401e

File tree

2 files changed: +170 -19 lines changed

backends/intel_hpu/custom_ops/llama_infer/fused_qkv_rope.cc

Lines changed: 158 additions & 19 deletions

@@ -22,6 +22,7 @@
 namespace custom_kernel {
 
 struct FusedQkvRopeParams {
+  ns_LayerNormKernel::Params rmsnorm_params;
   int head_dim;
   int num_head;
   int kv_num_head;
@@ -31,6 +32,7 @@ struct FusedQkvRopeParams {
   bool with_qkv_biases = false;
   bool fp8_proj = false;
   bool fp8_out = false;
+  bool use_qk_rmsnorm = false;
 };
 
 class FusedQkvRope : public HpuFusedOperator {
@@ -48,6 +50,10 @@ class FusedQkvRope : public HpuFusedOperator {
     int qkv_biases_index = 3;
     int scale_input_index = (params.with_qkv_biases ? (qkv_biases_index + 1)
                                                     : (rotary_embs_index + 1));
+    int q_gamma_index =
+        params.fp8_proj
+            ? params.fp8_out ? (scale_input_index + 5) : (scale_input_index + 2)
+            : scale_input_index;
 
     auto src = createTensorFromCT(&ct, src_index);
     auto qkv_weights = createTensorFromCT(&ct, qkv_weights_index);
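
The nested ternary above selects where the q gamma tensor lands in the op's flat input list; k gamma always follows at q_gamma_index + 1. A minimal standalone sketch of that arithmetic (illustrative only; it assumes src(0), qkv_weights(1), rotary_embs(2), and qkv_biases(3) when present, matching the indices visible above):

#include <cstdio>

// Sketch of the input-index arithmetic above. On the fp8 path the op carries
// scale_input/scale_weight (+2), and with fp8 output also scale_q/k/v (+5),
// before the q/k gamma tensors.
int QGammaIndex(bool with_qkv_biases, bool fp8_proj, bool fp8_out) {
  const int rotary_embs_index = 2;  // assumed, per the ordering above
  const int qkv_biases_index = 3;
  const int scale_input_index =
      with_qkv_biases ? qkv_biases_index + 1 : rotary_embs_index + 1;
  return fp8_proj ? (fp8_out ? scale_input_index + 5 : scale_input_index + 2)
                  : scale_input_index;
}

int main() {
  std::printf("bf16 op, biases: q_gamma=%d\n", QGammaIndex(true, false, false));
  std::printf("fp8 op, fp8_out: q_gamma=%d\n", QGammaIndex(true, true, true));
  return 0;
}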
@@ -180,7 +186,56 @@ class FusedQkvRope : public HpuFusedOperator {
 
     std::vector<synTensor> inputs_q;
     std::vector<synTensor> outputs_q;
-    inputs_q.push_back(q_split);
+    std::vector<synTensor> inputs_k;
+    std::vector<synTensor> outputs_k;
+
+    if (params.use_qk_rmsnorm) {
+      auto q_rmsnorm = createTensorNoPresist("q_rmsnorm", dtype_, outs[0].dims);
+      auto k_rmsnorm = createTensorNoPresist("k_rmsnorm", dtype_, kv_dims);
+      synTensor q_gamma = createTensorFromCT(&ct, q_gamma_index);
+      synTensor k_gamma = createTensorFromCT(&ct, q_gamma_index + 1);
+
+      auto tmp_q_dims = outs[0].dims;
+      tmp_q_dims[3] = 1;
+      auto q_rmsnorm_var =
+          createTensorNoPresist("q_rmsnorm_var", dtype_, tmp_q_dims);
+
+      auto tmp_k_dims = kv_dims;
+      tmp_k_dims[3] = 1;
+      auto k_rmsnorm_var =
+          createTensorNoPresist("k_rmsnorm_var", dtype_, tmp_k_dims);
+
+      std::vector<synTensor> rmsnorm_inputs_q;
+      rmsnorm_inputs_q.push_back(q_split);
+      rmsnorm_inputs_q.push_back(q_gamma);
+      std::vector<synTensor> rmsnorm_outputs_q;
+      rmsnorm_outputs_q.push_back(q_rmsnorm);
+      rmsnorm_outputs_q.push_back(q_rmsnorm_var);
+
+      AddNodeRmsNorm<T>(rmsnorm_inputs_q,
+                        rmsnorm_outputs_q,
+                        params.rmsnorm_params,
+                        guid_ + "q_rmsnorm");
+
+      std::vector<synTensor> rmsnorm_inputs_k;
+      rmsnorm_inputs_k.push_back(k_split);
+      rmsnorm_inputs_k.push_back(k_gamma);
+      std::vector<synTensor> rmsnorm_outputs_k;
+      rmsnorm_outputs_k.push_back(k_rmsnorm);
+      rmsnorm_outputs_k.push_back(k_rmsnorm_var);
+
+      AddNodeRmsNorm<T>(rmsnorm_inputs_k,
+                        rmsnorm_outputs_k,
+                        params.rmsnorm_params,
+                        guid_ + "k_rmsnorm");
+
+      inputs_q.push_back(q_rmsnorm);
+      inputs_k.push_back(k_rmsnorm);
+    } else {
+      inputs_q.push_back(q_split);
+      inputs_k.push_back(k_split);
+    }
+
     inputs_q.push_back(sin_sq);
     inputs_q.push_back(cos_sq);
 
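For reference, each AddNodeRmsNorm node normalizes its input over the head dimension with a learned gamma; the *_rmsnorm_var tensors (last dim set to 1) carry the per-row reduction the kernel emits alongside the normalized output. A minimal standalone sketch of the math, in plain float C++ (the HPU kernel's exact reduction layout and dtype handling are internal to the fused op):

#include <cmath>
#include <cstddef>
#include <vector>

// RMSNorm over one head_dim-length row: y[i] = x[i] / rms(x) * gamma[i],
// with rms(x) = sqrt(mean(x^2) + eps).
std::vector<float> RmsNorm(const std::vector<float>& x,
                           const std::vector<float>& gamma,
                           float eps) {
  float mean_sq = 0.0f;
  for (float v : x) mean_sq += v * v;
  mean_sq /= static_cast<float>(x.size());
  const float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * inv_rms * gamma[i];
  }
  return y;
}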
@@ -199,9 +254,6 @@ class FusedQkvRope : public HpuFusedOperator {
                        : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
     AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + "rope_q");
 
-    std::vector<synTensor> inputs_k;
-    std::vector<synTensor> outputs_k;
-    inputs_k.push_back(k_split);
     inputs_k.push_back(sin_sq);
     inputs_k.push_back(cos_sq);
 
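The ropeParams mode chosen above switches between the NEOX (half-split) and pairwise rotation layouts. A standalone sketch of the difference, assuming one sin/cos angle per rotated pair (illustrative; the fused kernel's actual sin/cos tensor layout may differ):

#include <cstddef>
#include <vector>

// Rotary-embedding sketch: sin_v/cos_v hold d/2 angles, one per pair.
// NEOX style rotates (x[i], x[i + d/2]); pairwise rotates (x[2i], x[2i + 1]).
std::vector<float> ApplyRope(const std::vector<float>& x,
                             const std::vector<float>& sin_v,
                             const std::vector<float>& cos_v,
                             bool use_neox_style) {
  const std::size_t d = x.size();
  std::vector<float> y(d);
  for (std::size_t i = 0; i < d / 2; ++i) {
    const std::size_t a = use_neox_style ? i : 2 * i;
    const std::size_t b = use_neox_style ? i + d / 2 : 2 * i + 1;
    y[a] = x[a] * cos_v[i] - x[b] * sin_v[i];
    y[b] = x[b] * cos_v[i] + x[a] * sin_v[i];
  }
  return y;
}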
@@ -277,11 +329,14 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
                         const paddle::optional<phi::DenseTensor>& scale_v,
                         phi::DenseTensor* query_states,
                         phi::DenseTensor* key_value_states,
+                        const paddle::optional<phi::DenseTensor>& q_norm_weight,
+                        const paddle::optional<phi::DenseTensor>& k_norm_weight,
                         const phi::Scalar& head_dim,
                         const phi::Scalar& num_head,
                         const phi::Scalar& total_batch,
                         const phi::Scalar& transpose,
-                        const phi::Scalar& use_neox_style) {
+                        const phi::Scalar& use_neox_style,
+                        const phi::Scalar& epsilon) {
   int total_batch_ = total_batch.to<int>();
   std::vector<int64_t> src_dims = phi::vectorize<int64_t>(src.dims());
   int bsz_seqlen = src_dims[0];
@@ -335,6 +390,16 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
     throw std::runtime_error(
         "Need both scale_input and scale_weight for FusedFp8QkvRopeKernel");
   }
+
+  if (q_norm_weight && k_norm_weight) {
+    guid_prefix += "_qk_norm";
+    ct.Add(q_norm_weight.get());
+    ct.Add(k_norm_weight.get());
+  } else if (q_norm_weight || k_norm_weight) {
+    throw std::runtime_error(
+        "Need both q_norm_weight and k_norm_weight for FusedQkvRopeKernel");
+  }
+
   guid_prefix += "_fwd_";
 
   OpCacheOperator op_info;
@@ -345,6 +410,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
   if (recipe == nullptr) {
     FusedQkvRopeParams params;
     memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedQkvRopeParams));
+    params.rmsnorm_params.epsValid = true;
+    params.rmsnorm_params.eps = epsilon.to<float>();
     params.head_dim = head_dim_;
     params.num_head = num_head_;
     params.kv_num_head = kv_num_head;
@@ -360,6 +427,9 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
     if (scale_q) {
       params.fp8_out = true;
     }
+    if (q_norm_weight && k_norm_weight) {
+      params.use_qk_rmsnorm = true;
+    }
 
     FusedQkvRope op(guid_prefix, op_info.datatype_);
     op.AddNode<T>(ct, params);
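Putting the pieces together, the params block above would be populated along these lines for a hypothetical llama-style configuration (all numeric values are illustrative, not from this patch; the fields are those defined in FusedQkvRopeParams above):

#include <cstring>

// Hypothetical FusedQkvRopeParams setup: head_dim 128, 32 query heads,
// 8 kv heads, RMSNorm eps 1e-6, q/k norm enabled. Values illustrative.
FusedQkvRopeParams params;
memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedQkvRopeParams));
params.rmsnorm_params.epsValid = true;
params.rmsnorm_params.eps = 1e-6f;
params.head_dim = 128;
params.num_head = 32;
params.kv_num_head = 8;
params.use_qk_rmsnorm = true;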
@@ -390,11 +460,14 @@ void CallFusedQkvRopeKernel(
     const paddle::optional<phi::DenseTensor>& scale_v,
     phi::DenseTensor* query_states,
     phi::DenseTensor* key_value_states,
+    const paddle::optional<phi::DenseTensor>& q_norm,
+    const paddle::optional<phi::DenseTensor>& k_norm,
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
     const phi::Scalar& total_batch,
     const phi::Scalar& transpose,
-    const phi::Scalar& use_neox_style) {
+    const phi::Scalar& use_neox_style,
+    const phi::Scalar& epsilon) {
   if (src.dtype() == phi::DataType::FLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::float16>(dev_ctx,
                                                            src,
@@ -408,11 +481,14 @@ void CallFusedQkvRopeKernel(
                                                            scale_v,
                                                            query_states,
                                                            key_value_states,
+                                                           q_norm,
+                                                           k_norm,
                                                            head_dim,
                                                            num_head,
                                                            total_batch,
                                                            transpose,
-                                                           use_neox_style);
+                                                           use_neox_style,
+                                                           epsilon);
   } else if (src.dtype() == phi::DataType::BFLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::bfloat16>(dev_ctx,
                                                             src,
@@ -426,11 +502,14 @@ void CallFusedQkvRopeKernel(
                                                             scale_v,
                                                             query_states,
                                                             key_value_states,
+                                                            q_norm,
+                                                            k_norm,
                                                             head_dim,
                                                             num_head,
                                                             total_batch,
                                                             transpose,
-                                                            use_neox_style);
+                                                            use_neox_style,
+                                                            epsilon);
   } else {
     throw std::runtime_error("Unsupported data type for FusedQkvRopeKernel");
   }
@@ -441,11 +520,14 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
     const paddle::Tensor& qkv_weights,
     const paddle::optional<paddle::Tensor>& qkv_biases,
     const paddle::Tensor& rotary_embs,
+    const paddle::optional<paddle::Tensor>& q_norm_weights,
+    const paddle::optional<paddle::Tensor>& k_norm_weights,
     int head_dim,
     int num_head,
     int total_batch,
     bool transpose,
-    bool use_neox_style) {
+    bool use_neox_style,
+    float epsilon) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -461,6 +543,22 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
     qkv_biases_tensor = paddle::optional<phi::DenseTensor>(*qkv_biases_dt);
   }
 
+  auto q_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (q_norm_weights) {
+    auto q_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(q_norm_weights->impl().get());
+    q_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*q_norm_weights_dt);
+  }
+
+  auto k_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (k_norm_weights) {
+    auto k_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(k_norm_weights->impl().get());
+    k_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*k_norm_weights_dt);
+  }
+
   // allocate memory on device.
   int64_t bsz = src.dims()[0];
   int64_t seq_len = bsz / total_batch;
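The unwrap blocks above repeat verbatim for both gamma tensors here and again in FusedFp8QkvRopeImpl below; a hypothetical helper along these lines (not part of this patch, built only from the paddle::optional and phi::DenseTensor calls already used in this file) could fold the four copies into one:

// Hypothetical helper (assumption, not in this patch): convert an optional
// paddle::Tensor into the optional phi::DenseTensor the kernels consume.
static paddle::optional<phi::DenseTensor> ToDenseOptional(
    const paddle::optional<paddle::Tensor>& t) {
  if (!t) {
    return paddle::optional<phi::DenseTensor>();
  }
  auto* dt = static_cast<phi::DenseTensor*>(t->impl().get());
  return paddle::optional<phi::DenseTensor>(*dt);
}

// Usage: auto q_norm_weights_tensor = ToDenseOptional(q_norm_weights);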
@@ -492,11 +590,14 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
       paddle::optional<phi::DenseTensor>(),
       query_states.get(),
       key_value_states.get(),
+      q_norm_weights_tensor,
+      k_norm_weights_tensor,
       phi::Scalar(head_dim),
       phi::Scalar(num_head),
       phi::Scalar(total_batch),
       phi::Scalar(transpose),
-      phi::Scalar(use_neox_style));
+      phi::Scalar(use_neox_style),
+      phi::Scalar(epsilon));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }
 
@@ -505,6 +606,8 @@ std::vector<std::vector<int64_t>> FusedQkvRopeShape(
     const std::vector<int64_t>& qkv_weights_shape,
     const paddle::optional<std::vector<int64_t>>& qkv_biases_shape,
     const std::vector<int64_t>& rotary_embs_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weights_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weights_shape,
     int head_dim,
     int num_head,
     int total_batch,
@@ -523,19 +626,26 @@ std::vector<paddle::DataType> FusedQkvRopeDtype(
     const paddle::DataType& src_dtype,
     const paddle::DataType& qkv_weights_dtype,
     const paddle::optional<paddle::DataType>& qkv_biases_dtype,
-    const paddle::DataType& rotary_embs_dtype) {
+    const paddle::DataType& rotary_embs_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weights_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weights_dtype) {
   return {src_dtype, src_dtype};
 }
 
 PD_BUILD_OP(fused_qkv_rope_bf16)
-    .Inputs(
-        {"src", "qkv_weights", paddle::Optional("qkv_biases"), "rotary_embs"})
+    .Inputs({"src",
+             "qkv_weights",
+             paddle::Optional("qkv_biases"),
+             "rotary_embs",
+             paddle::Optional("q_norm_weights"),
+             paddle::Optional("k_norm_weights")})
     .Outputs({"query_states", "key_value_states"})
     .Attrs({"head_dim: int",
             "num_head: int",
            "total_batch: int",
            "transpose: bool",
-           "use_neox_style: bool"})
+           "use_neox_style: bool",
+           "epsilon: float"})
    .SetKernelFn(PD_KERNEL(FusedQkvRopeImpl))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedQkvRopeShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedQkvRopeDtype));
@@ -550,11 +660,14 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
     const paddle::optional<paddle::Tensor>& scale_q,
     const paddle::optional<paddle::Tensor>& scale_k,
     const paddle::optional<paddle::Tensor>& scale_v,
+    const paddle::optional<paddle::Tensor>& q_norm_weights,
+    const paddle::optional<paddle::Tensor>& k_norm_weights,
     int head_dim,
     int num_head,
     int total_batch,
     bool transpose,
-    bool use_neox_style) {
+    bool use_neox_style,
+    float epsilon) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -599,6 +712,22 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
     scale_v_tensor = paddle::optional<phi::DenseTensor>(*scale_v_dt);
   }
 
+  auto q_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (q_norm_weights) {
+    auto q_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(q_norm_weights->impl().get());
+    q_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*q_norm_weights_dt);
+  }
+
+  auto k_norm_weights_tensor = paddle::optional<phi::DenseTensor>();
+  if (k_norm_weights) {
+    auto k_norm_weights_dt =
+        static_cast<phi::DenseTensor*>(k_norm_weights->impl().get());
+    k_norm_weights_tensor =
+        paddle::optional<phi::DenseTensor>(*k_norm_weights_dt);
+  }
+
   // allocate memory on device.
   int64_t bsz = src.dims()[0];
   int64_t seq_len = bsz / total_batch;
@@ -636,11 +765,14 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
       scale_v_tensor,
       query_states.get(),
       key_value_states.get(),
+      q_norm_weights_tensor,
+      k_norm_weights_tensor,
       phi::Scalar(head_dim),
       phi::Scalar(num_head),
       phi::Scalar(total_batch),
       phi::Scalar(transpose),
-      phi::Scalar(use_neox_style));
+      phi::Scalar(use_neox_style),
+      phi::Scalar(epsilon));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }
 
@@ -651,6 +783,8 @@ std::vector<std::vector<int64_t>> FusedFp8QkvRopeShape(
     const std::vector<int64_t>& rotary_embs_shape,
     const std::vector<int64_t>& scale_input_shape,
     const std::vector<int64_t>& scale_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& q_norm_weights_shape,
+    const paddle::optional<std::vector<int64_t>>& k_norm_weights_shape,
     int head_dim,
     int num_head,
     int total_batch,
@@ -671,7 +805,9 @@ std::vector<paddle::DataType> FusedFp8QkvRopeDtype(
     const paddle::optional<paddle::DataType>& qkv_biases_dtype,
     const paddle::DataType& rotary_embs_dtype,
     const paddle::DataType& scale_input_dtype,
-    const paddle::DataType& scale_weight_dtype) {
+    const paddle::DataType& scale_weight_dtype,
+    const paddle::optional<paddle::DataType>& q_norm_weights_dtype,
+    const paddle::optional<paddle::DataType>& k_norm_weights_dtype) {
   return {src_dtype, src_dtype};
 }
 
@@ -684,13 +820,16 @@ PD_BUILD_OP(fused_qkv_rope)
              "scale_weight",
              "scale_q",
              "scale_k",
-             "scale_v"})
+             "scale_v",
+             paddle::Optional("q_rmsnorm_weights"),
+             paddle::Optional("k_rmsnorm_weights")})
     .Outputs({"query_states", "key_value_states"})
     .Attrs({"head_dim: int",
             "num_head: int",
            "total_batch: int",
            "transpose: bool",
-           "use_neox_style: bool"})
+           "use_neox_style: bool",
+           "epsilon: float"})
    .SetKernelFn(PD_KERNEL(FusedFp8QkvRopeImpl))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8QkvRopeShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8QkvRopeDtype));
