Skip to content

Commit 7a8cc78

Browse files
authored
[INTEL_HPU] MoE weights from list to stack (#1871)
1 parent 1ca54ab commit 7a8cc78

File tree

2 files changed

+180
-97
lines changed

2 files changed

+180
-97
lines changed

backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc

Lines changed: 112 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,11 @@ void FusedGateMoeKernel(
344344
const phi::DenseTensor& hidden_states,
345345
const phi::DenseTensor& gate_out,
346346
const paddle::optional<phi::DenseTensor>& gate_correction_bias,
347-
const std::vector<phi::DenseTensor>& gate_up_weights,
348-
const std::vector<phi::DenseTensor>& down_weights,
349-
const paddle::optional<std::vector<phi::DenseTensor>>& scales,
347+
const phi::DenseTensor& gate_up_weights,
348+
const phi::DenseTensor& down_weights,
349+
const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
350+
const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
351+
const paddle::optional<phi::DenseTensor>& down_weights_scales,
350352
phi::DenseTensor* final_hidden_states,
351353
const int top_k,
352354
const bool moe_use_gate_correction_bias,
@@ -358,14 +360,20 @@ void FusedGateMoeKernel(
358360
const bool measurement_mode,
359361
const bool dynamic_scale,
360362
const int block_size) {
363+
std::vector<int64_t> gate_up_weights_dims =
364+
phi::vectorize<int64_t>(gate_up_weights.dims());
365+
std::vector<int64_t> down_weights_dims =
366+
phi::vectorize<int64_t>(down_weights.dims());
361367
FusedGateMoeParams params;
362368
memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedGateMoeParams));
363369
params.topk = top_k;
364370
params.moe_use_gate_correction_bias = moe_use_gate_correction_bias;
365371
params.norm_topk_prob = norm_topk_prob;
366372
params.permuted_weights = permuted_weights;
367-
params.fused_gemm = (gate_up_weights.size() == down_weights.size());
368-
params.num_experts = down_weights.size();
373+
// TODO(yanfeich): add optional up_weights
374+
params.fused_gemm = (gate_up_weights_dims[2] == down_weights_dims[1] * 2);
375+
params.measurement_mode = measurement_mode;
376+
params.num_experts = gate_up_weights_dims[0];
369377
params.experts_min = experts_min;
370378
params.experts_max = experts_max;
371379
params.dynamic_scale = dynamic_scale;
@@ -380,16 +388,18 @@ void FusedGateMoeKernel(
380388
if (moe_use_gate_correction_bias) {
381389
ct.Add(gate_correction_bias.get());
382390
}
383-
for (const auto& t : gate_up_weights) {
384-
ct.Add(t);
391+
392+
ct.AddN(gate_up_weights);
393+
ct.AddN(down_weights);
394+
395+
if (intermediate_hidden_states_scales) {
396+
ct.AddN(intermediate_hidden_states_scales.get());
385397
}
386-
for (const auto& t : down_weights) {
387-
ct.Add(t);
398+
if (gate_up_weights_scales) {
399+
ct.AddN(gate_up_weights_scales.get());
388400
}
389-
if (scales) {
390-
for (const auto& t : scales.get()) {
391-
ct.Add(t);
392-
}
401+
if (down_weights_scales) {
402+
ct.AddN(down_weights_scales.get());
393403
}
394404

395405
ct.Add(*final_hidden_states, false);
@@ -424,9 +434,11 @@ void CallFusedGateMoeKernel(
424434
const phi::DenseTensor& hidden_states,
425435
const phi::DenseTensor& gate_out,
426436
const paddle::optional<phi::DenseTensor>& gate_correction_bias,
427-
const std::vector<phi::DenseTensor>& gate_up_weights,
428-
const std::vector<phi::DenseTensor>& down_weights,
429-
const paddle::optional<std::vector<phi::DenseTensor>>& scales,
437+
const phi::DenseTensor& gate_up_weights,
438+
const phi::DenseTensor& down_weights,
439+
const paddle::optional<phi::DenseTensor>& intermediate_hidden_states_scales,
440+
const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
441+
const paddle::optional<phi::DenseTensor>& down_weights_scales,
430442
phi::DenseTensor* final_hidden_states,
431443
const int top_k,
432444
const bool moe_use_gate_correction_bias,
@@ -450,7 +462,9 @@ void CallFusedGateMoeKernel(
450462
gate_correction_bias,
451463
gate_up_weights,
452464
down_weights,
453-
scales,
465+
intermediate_hidden_states_scales,
466+
gate_up_weights_scales,
467+
down_weights_scales,
454468
final_hidden_states,
455469
top_k,
456470
moe_use_gate_correction_bias,
@@ -471,7 +485,9 @@ void CallFusedGateMoeKernel(
471485
gate_correction_bias,
472486
gate_up_weights,
473487
down_weights,
474-
scales,
488+
intermediate_hidden_states_scales,
489+
gate_up_weights_scales,
490+
down_weights_scales,
475491
final_hidden_states,
476492
top_k,
477493
moe_use_gate_correction_bias,
@@ -493,8 +509,8 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
493509
const paddle::Tensor& hidden_states,
494510
const paddle::Tensor& gate_out,
495511
const paddle::optional<paddle::Tensor>& gate_correction_bias,
496-
const std::vector<paddle::Tensor>& gate_up_weights,
497-
const std::vector<paddle::Tensor>& down_weights,
512+
const paddle::Tensor& gate_up_weights,
513+
const paddle::Tensor& down_weights,
498514
const int top_k,
499515
const bool moe_use_gate_correction_bias,
500516
const bool norm_topk_prob,
@@ -519,16 +535,10 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
519535
paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
520536
}
521537

522-
std::vector<phi::DenseTensor> gate_up_weights_vec;
523-
for (const auto& t : gate_up_weights) {
524-
gate_up_weights_vec.push_back(
525-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
526-
}
527-
std::vector<phi::DenseTensor> down_weights_vec;
528-
for (const auto& t : down_weights) {
529-
down_weights_vec.push_back(
530-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
531-
}
538+
auto gate_up_weights_tensor =
539+
static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
540+
auto down_weights_tensor =
541+
static_cast<const phi::DenseTensor*>(down_weights.impl().get());
532542

533543
std::shared_ptr<phi::DenseTensor> final_hidden_states =
534544
std::make_shared<phi::DenseTensor>();
@@ -540,9 +550,11 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
540550
*hidden_states_tensor,
541551
*gate_out_tensor,
542552
gate_correction_tensor,
543-
gate_up_weights_vec,
544-
down_weights_vec,
545-
paddle::optional<std::vector<phi::DenseTensor>>(), /* scales */
553+
*gate_up_weights_tensor,
554+
*down_weights_tensor,
555+
paddle::optional<phi::DenseTensor>(), /* int..hid..st.._scales_tensor */
556+
paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales_tensor */
557+
paddle::optional<phi::DenseTensor>(), /* down_weights_scales_tensor */
546558
final_hidden_states.get(),
547559
top_k,
548560
moe_use_gate_correction_bias,
@@ -563,12 +575,11 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
563575
const paddle::Tensor& hidden_states,
564576
const paddle::Tensor& gate_out,
565577
const paddle::optional<paddle::Tensor>& gate_correction_bias,
566-
const std::vector<paddle::Tensor>& gate_up_weights,
567-
const std::vector<paddle::Tensor>& down_weights,
568-
const paddle::optional<std::vector<paddle::Tensor>>&
569-
intermediate_hidden_states_scales,
570-
const std::vector<paddle::Tensor>& gate_up_weights_scales,
571-
const std::vector<paddle::Tensor>& down_weights_scales,
578+
const paddle::Tensor& gate_up_weights,
579+
const paddle::Tensor& down_weights,
580+
const paddle::optional<paddle::Tensor>& intermediate_hidden_states_scales,
581+
const paddle::Tensor& gate_up_weights_scales,
582+
const paddle::Tensor& down_weights_scales,
572583
const int top_k,
573584
const bool moe_use_gate_correction_bias,
574585
const bool norm_topk_prob,
@@ -593,33 +604,35 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
593604
paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
594605
}
595606

596-
std::vector<phi::DenseTensor> gate_up_weights_vec;
597-
for (const auto& t : gate_up_weights) {
598-
gate_up_weights_vec.push_back(
599-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
600-
}
601-
std::vector<phi::DenseTensor> down_weights_vec;
602-
for (const auto& t : down_weights) {
603-
down_weights_vec.push_back(
604-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
605-
}
607+
auto gate_up_weights_tensor =
608+
static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
609+
auto down_weights_tensor =
610+
static_cast<const phi::DenseTensor*>(down_weights.impl().get());
606611

607612
bool dynamic_scale = true;
608-
std::vector<phi::DenseTensor> scales_vec;
613+
auto intermediate_hidden_states_scales_tensor =
614+
paddle::optional<phi::DenseTensor>();
609615
if (intermediate_hidden_states_scales) {
610616
dynamic_scale = false;
611-
for (const auto& t : intermediate_hidden_states_scales.get()) {
612-
scales_vec.push_back(
613-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
614-
}
615-
}
616-
for (const auto& t : gate_up_weights_scales) {
617-
scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
618-
}
619-
for (const auto& t : down_weights_scales) {
620-
scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
617+
auto intermediate_hidden_states_scales_dt = static_cast<phi::DenseTensor*>(
618+
intermediate_hidden_states_scales->impl().get());
619+
intermediate_hidden_states_scales_tensor =
620+
paddle::optional<phi::DenseTensor>(
621+
*intermediate_hidden_states_scales_dt);
621622
}
622623

624+
auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
625+
auto gate_up_weights_scales_dt =
626+
static_cast<const phi::DenseTensor*>(gate_up_weights_scales.impl().get());
627+
gate_up_weights_scales_tensor =
628+
paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
629+
630+
auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
631+
auto down_weights_scales_dt =
632+
static_cast<const phi::DenseTensor*>(down_weights_scales.impl().get());
633+
down_weights_scales_tensor =
634+
paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
635+
623636
std::shared_ptr<phi::DenseTensor> final_hidden_states =
624637
std::make_shared<phi::DenseTensor>();
625638
final_hidden_states->Resize(hidden_states.dims());
@@ -630,9 +643,11 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
630643
*hidden_states_tensor,
631644
*gate_out_tensor,
632645
gate_correction_tensor,
633-
gate_up_weights_vec,
634-
down_weights_vec,
635-
scales_vec,
646+
*gate_up_weights_tensor,
647+
*down_weights_tensor,
648+
intermediate_hidden_states_scales_tensor,
649+
gate_up_weights_scales_tensor,
650+
down_weights_scales_tensor,
636651
final_hidden_states.get(),
637652
top_k,
638653
moe_use_gate_correction_bias,
@@ -652,10 +667,10 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
652667
const paddle::Tensor& hidden_states,
653668
const paddle::Tensor& gate_out,
654669
const paddle::optional<paddle::Tensor>& gate_correction_bias,
655-
const std::vector<paddle::Tensor>& gate_up_weights,
656-
const std::vector<paddle::Tensor>& down_weights,
657-
const std::vector<paddle::Tensor>& gate_up_weights_scales,
658-
const std::vector<paddle::Tensor>& down_weights_scales,
670+
const paddle::Tensor& gate_up_weights,
671+
const paddle::Tensor& down_weights,
672+
const paddle::Tensor& gate_up_weights_scales,
673+
const paddle::Tensor& down_weights_scales,
659674
const int top_k,
660675
const bool moe_use_gate_correction_bias,
661676
const bool norm_topk_prob,
@@ -681,24 +696,22 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
681696
paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
682697
}
683698

684-
std::vector<phi::DenseTensor> gate_up_weights_vec;
685-
for (const auto& t : gate_up_weights) {
686-
gate_up_weights_vec.push_back(
687-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
688-
}
689-
std::vector<phi::DenseTensor> down_weights_vec;
690-
for (const auto& t : down_weights) {
691-
down_weights_vec.push_back(
692-
*static_cast<const phi::DenseTensor*>(t.impl().get()));
693-
}
699+
auto gate_up_weights_tensor =
700+
static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
701+
auto down_weights_tensor =
702+
static_cast<const phi::DenseTensor*>(down_weights.impl().get());
694703

695-
std::vector<phi::DenseTensor> scales_vec;
696-
for (const auto& t : gate_up_weights_scales) {
697-
scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
698-
}
699-
for (const auto& t : down_weights_scales) {
700-
scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
701-
}
704+
auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
705+
auto gate_up_weights_scales_dt =
706+
static_cast<const phi::DenseTensor*>(gate_up_weights_scales.impl().get());
707+
gate_up_weights_scales_tensor =
708+
paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
709+
710+
auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
711+
auto down_weights_scales_dt =
712+
static_cast<const phi::DenseTensor*>(down_weights_scales.impl().get());
713+
down_weights_scales_tensor =
714+
paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
702715

703716
std::shared_ptr<phi::DenseTensor> final_hidden_states =
704717
std::make_shared<phi::DenseTensor>();
@@ -710,9 +723,11 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
710723
*hidden_states_tensor,
711724
*gate_out_tensor,
712725
gate_correction_tensor,
713-
gate_up_weights_vec,
714-
down_weights_vec,
715-
scales_vec,
726+
*gate_up_weights_tensor,
727+
*down_weights_tensor,
728+
paddle::optional<phi::DenseTensor>(), /* scales */
729+
gate_up_weights_scales_tensor,
730+
down_weights_scales_tensor,
716731
final_hidden_states.get(),
717732
top_k,
718733
moe_use_gate_correction_bias,
@@ -755,8 +770,8 @@ PD_BUILD_OP(fused_gate_moe)
755770
.Inputs({"hidden_states",
756771
"gate_out",
757772
paddle::Optional("gate_correction_bias"),
758-
paddle::Vec("gate_up_weights"),
759-
paddle::Vec("down_weights")})
773+
"gate_up_weights",
774+
"down_weights"})
760775
.Outputs({"final_hidden_states"})
761776
.Attrs({"top_k: int",
762777
"moe_use_gate_correction_bias: bool",
@@ -780,11 +795,11 @@ PD_BUILD_OP(fused_gate_moe_fp8)
780795
.Inputs({"hidden_states",
781796
"gate_out",
782797
paddle::Optional("gate_correction_bias"),
783-
paddle::Vec("gate_up_weights"),
784-
paddle::Vec("down_weights"),
798+
"gate_up_weights",
799+
"down_weights",
785800
paddle::Optional(paddle::Vec("intermediate_hidden_states_scales")),
786-
paddle::Vec("gate_up_weights_scales"),
787-
paddle::Vec("down_weights_scales")})
801+
"gate_up_weights_scales",
802+
"down_weights_scales"})
788803
.Outputs({"final_hidden_states"})
789804
.Attrs({"top_k: int",
790805
"moe_use_gate_correction_bias: bool",
@@ -807,10 +822,10 @@ PD_BUILD_OP(fused_gate_moe_blockwise_fp8)
807822
.Inputs({"hidden_states",
808823
"gate_out",
809824
paddle::Optional("gate_correction_bias"),
810-
paddle::Vec("gate_up_weights"),
811-
paddle::Vec("down_weights"),
812-
paddle::Vec("gate_up_weights_scales"),
813-
paddle::Vec("down_weights_scales")})
825+
"gate_up_weights",
826+
"down_weights",
827+
"gate_up_weights_scales",
828+
"down_weights_scales"})
814829
.Outputs({"final_hidden_states"})
815830
.Attrs({"top_k: int",
816831
"moe_use_gate_correction_bias: bool",

0 commit comments

Comments
 (0)