@@ -413,10 +413,13 @@ void FusedGateMoeKernel(
413413 const phi::DenseTensor& hidden_states,
414414 const phi::DenseTensor& gate_weights,
415415 const paddle::optional<phi::DenseTensor>& gate_correction_bias,
416- const std::vector< phi::DenseTensor> & gate_up_weights,
417- const std::vector< phi::DenseTensor> & down_weights,
416+ const phi::DenseTensor& gate_up_weights,
417+ const phi::DenseTensor& down_weights,
418418 const paddle::optional<phi::DenseTensor>& hidden_states_scales,
419- const paddle::optional<std::vector<phi::DenseTensor>>& scales,
419+ const paddle::optional<std::vector<phi::DenseTensor>>&
420+ intermediate_hidden_states_scales,
421+ const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
422+ const paddle::optional<phi::DenseTensor>& down_weights_scales,
420423 phi::DenseTensor* final_hidden_states,
421424 const int top_k,
422425 const bool norm_topk_prob,
@@ -428,13 +431,18 @@ void FusedGateMoeKernel(
428431 const bool dynamic_scale,
429432 const int block_size,
430433 const int chunk_size) {
434+ std::vector<int64_t > gate_up_weights_dims =
435+ phi::vectorize<int64_t >(gate_up_weights.dims ());
436+ std::vector<int64_t > down_weights_dims =
437+ phi::vectorize<int64_t >(down_weights.dims ());
431438 FusedGateMoeParams params;
432439 memset (reinterpret_cast <void *>(&params), 0x00 , sizeof (FusedGateMoeParams));
433440 params.topk = top_k;
434441 params.norm_topk_prob = norm_topk_prob;
435442 params.permuted_weights = permuted_weights;
436- params.fused_gemm = (gate_up_weights.size () == down_weights.size ());
437- params.num_experts = down_weights.size ();
443+ params.fused_gemm = (gate_up_weights_dims[2 ] == down_weights_dims[1 ] * 2 );
444+ params.measurement_mode = measurement_mode;
445+ params.num_experts = gate_up_weights_dims[0 ];
438446 params.experts_min = experts_min;
439447 params.experts_max = experts_max;
440448 params.hidden_states_static_quant = false ;
@@ -456,17 +464,20 @@ void FusedGateMoeKernel(
456464 ct.Add (hidden_states_scales.get ());
457465 params.hidden_states_static_quant = true ;
458466 }
459- for (const auto & t : gate_up_weights) {
460- ct.Add (t);
461- }
462- for (const auto & t : down_weights) {
463- ct.Add (t);
464- }
465- if (scales) {
466- for (const auto & t : scales.get ()) {
467+ ct.AddN (gate_up_weights);
468+ ct.AddN (down_weights);
469+
470+ if (intermediate_hidden_states_scales) {
471+ for (const auto & t : intermediate_hidden_states_scales.get ()) {
467472 ct.Add (t);
468473 }
469474 }
475+ if (gate_up_weights_scales) {
476+ ct.AddN (gate_up_weights_scales.get ());
477+ }
478+ if (down_weights_scales) {
479+ ct.AddN (down_weights_scales.get ());
480+ }
470481
471482 ct.Add (*final_hidden_states, false );
472483
@@ -500,10 +511,13 @@ void CallFusedGateMoeKernel(
500511 const phi::DenseTensor& hidden_states,
501512 const phi::DenseTensor& gate_weights,
502513 const paddle::optional<phi::DenseTensor>& gate_correction_bias,
503- const std::vector< phi::DenseTensor> & gate_up_weights,
504- const std::vector< phi::DenseTensor> & down_weights,
514+ const phi::DenseTensor& gate_up_weights,
515+ const phi::DenseTensor& down_weights,
505516 const paddle::optional<phi::DenseTensor>& hidden_states_scales,
506- const paddle::optional<std::vector<phi::DenseTensor>>& scales,
517+ const paddle::optional<std::vector<phi::DenseTensor>>&
518+ intermediate_hidden_states_scales,
519+ const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
520+ const paddle::optional<phi::DenseTensor>& down_weights_scales,
507521 phi::DenseTensor* final_hidden_states,
508522 const int top_k,
509523 const bool norm_topk_prob,
@@ -528,7 +542,9 @@ void CallFusedGateMoeKernel(
528542 gate_up_weights,
529543 down_weights,
530544 hidden_states_scales,
531- scales,
545+ intermediate_hidden_states_scales,
546+ gate_up_weights_scales,
547+ down_weights_scales,
532548 final_hidden_states,
533549 top_k,
534550 norm_topk_prob,
@@ -550,7 +566,9 @@ void CallFusedGateMoeKernel(
550566 gate_up_weights,
551567 down_weights,
552568 hidden_states_scales,
553- scales,
569+ intermediate_hidden_states_scales,
570+ gate_up_weights_scales,
571+ down_weights_scales,
554572 final_hidden_states,
555573 top_k,
556574 norm_topk_prob,
@@ -572,8 +590,8 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
572590 const paddle::Tensor& hidden_states,
573591 const paddle::Tensor& gate_weights,
574592 const paddle::optional<paddle::Tensor>& gate_correction_bias,
575- const std::vector< paddle::Tensor> & gate_up_weights,
576- const std::vector< paddle::Tensor> & down_weights,
593+ const paddle::Tensor& gate_up_weights,
594+ const paddle::Tensor& down_weights,
577595 const int top_k,
578596 const bool norm_topk_prob,
579597 const bool permuted_weights,
@@ -598,16 +616,10 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
598616 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
599617 }
600618
601- std::vector<phi::DenseTensor> gate_up_weights_vec;
602- for (const auto & t : gate_up_weights) {
603- gate_up_weights_vec.push_back (
604- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
605- }
606- std::vector<phi::DenseTensor> down_weights_vec;
607- for (const auto & t : down_weights) {
608- down_weights_vec.push_back (
609- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
610- }
619+ auto gate_up_weights_tensor =
620+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
621+ auto down_weights_tensor =
622+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
611623
612624 std::shared_ptr<phi::DenseTensor> final_hidden_states =
613625 std::make_shared<phi::DenseTensor>();
@@ -619,10 +631,12 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
619631 *hidden_states_tensor,
620632 *gate_weights_tensor,
621633 gate_correction_tensor,
622- gate_up_weights_vec ,
623- down_weights_vec ,
634+ *gate_up_weights_tensor ,
635+ *down_weights_tensor ,
624636 paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
625- paddle::optional<std::vector<phi::DenseTensor>>(), /* scales */
637+ paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
638+ paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales */
639+ paddle::optional<phi::DenseTensor>(), /* down_weights_scales */
626640 final_hidden_states.get (),
627641 top_k,
628642 norm_topk_prob,
@@ -643,13 +657,13 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
643657 const paddle::Tensor& hidden_states,
644658 const paddle::Tensor& gate_weights,
645659 const paddle::optional<paddle::Tensor>& gate_correction_bias,
646- const std::vector< paddle::Tensor> & gate_up_weights,
647- const std::vector< paddle::Tensor> & down_weights,
660+ const paddle::Tensor& gate_up_weights,
661+ const paddle::Tensor& down_weights,
648662 const paddle::optional<paddle::Tensor>& hidden_states_scales,
649663 const paddle::optional<std::vector<paddle::Tensor>>&
650664 intermediate_hidden_states_scales,
651- const std::vector< paddle::Tensor> & gate_up_weights_scales,
652- const std::vector< paddle::Tensor> & down_weights_scales,
665+ const paddle::Tensor& gate_up_weights_scales,
666+ const paddle::Tensor& down_weights_scales,
653667 const int top_k,
654668 const bool norm_topk_prob,
655669 const bool permuted_weights,
@@ -674,16 +688,10 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
674688 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
675689 }
676690
677- std::vector<phi::DenseTensor> gate_up_weights_vec;
678- for (const auto & t : gate_up_weights) {
679- gate_up_weights_vec.push_back (
680- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
681- }
682- std::vector<phi::DenseTensor> down_weights_vec;
683- for (const auto & t : down_weights) {
684- down_weights_vec.push_back (
685- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
686- }
691+ auto gate_up_weights_tensor =
692+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
693+ auto down_weights_tensor =
694+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
687695
688696 auto hidden_states_scales_tensor = paddle::optional<phi::DenseTensor>();
689697 if (hidden_states_scales) {
@@ -702,12 +710,17 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
702710 *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
703711 }
704712 }
705- for (const auto & t : gate_up_weights_scales) {
706- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
707- }
708- for (const auto & t : down_weights_scales) {
709- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
710- }
713+ auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
714+ auto gate_up_weights_scales_dt =
715+ static_cast <const phi::DenseTensor*>(gate_up_weights_scales.impl ().get ());
716+ gate_up_weights_scales_tensor =
717+ paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
718+
719+ auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
720+ auto down_weights_scales_dt =
721+ static_cast <const phi::DenseTensor*>(down_weights_scales.impl ().get ());
722+ down_weights_scales_tensor =
723+ paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
711724
712725 std::shared_ptr<phi::DenseTensor> final_hidden_states =
713726 std::make_shared<phi::DenseTensor>();
@@ -719,10 +732,12 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
719732 *hidden_states_tensor,
720733 *gate_weights_tensor,
721734 gate_correction_tensor,
722- gate_up_weights_vec ,
723- down_weights_vec ,
735+ *gate_up_weights_tensor ,
736+ *down_weights_tensor ,
724737 hidden_states_scales_tensor,
725738 scales_vec,
739+ gate_up_weights_scales_tensor,
740+ down_weights_scales_tensor,
726741 final_hidden_states.get (),
727742 top_k,
728743 norm_topk_prob,
@@ -742,10 +757,10 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
742757 const paddle::Tensor& hidden_states,
743758 const paddle::Tensor& gate_weights,
744759 const paddle::optional<paddle::Tensor>& gate_correction_bias,
745- const std::vector< paddle::Tensor> & gate_up_weights,
746- const std::vector< paddle::Tensor> & down_weights,
747- const std::vector< paddle::Tensor> & gate_up_weights_scales,
748- const std::vector< paddle::Tensor> & down_weights_scales,
760+ const paddle::Tensor& gate_up_weights,
761+ const paddle::Tensor& down_weights,
762+ const paddle::Tensor& gate_up_weights_scales,
763+ const paddle::Tensor& down_weights_scales,
749764 const int top_k,
750765 const bool norm_topk_prob,
751766 const bool permuted_weights,
@@ -771,24 +786,22 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
771786 paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
772787 }
773788
774- std::vector<phi::DenseTensor> gate_up_weights_vec;
775- for (const auto & t : gate_up_weights) {
776- gate_up_weights_vec.push_back (
777- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
778- }
779- std::vector<phi::DenseTensor> down_weights_vec;
780- for (const auto & t : down_weights) {
781- down_weights_vec.push_back (
782- *static_cast <const phi::DenseTensor*>(t.impl ().get ()));
783- }
789+ auto gate_up_weights_tensor =
790+ static_cast <const phi::DenseTensor*>(gate_up_weights.impl ().get ());
791+ auto down_weights_tensor =
792+ static_cast <const phi::DenseTensor*>(down_weights.impl ().get ());
784793
785- std::vector<phi::DenseTensor> scales_vec;
786- for (const auto & t : gate_up_weights_scales) {
787- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
788- }
789- for (const auto & t : down_weights_scales) {
790- scales_vec.push_back (*static_cast <const phi::DenseTensor*>(t.impl ().get ()));
791- }
794+ auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
795+ auto gate_up_weights_scales_dt =
796+ static_cast <const phi::DenseTensor*>(gate_up_weights_scales.impl ().get ());
797+ gate_up_weights_scales_tensor =
798+ paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
799+
800+ auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
801+ auto down_weights_scales_dt =
802+ static_cast <const phi::DenseTensor*>(down_weights_scales.impl ().get ());
803+ down_weights_scales_tensor =
804+ paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
792805
793806 std::shared_ptr<phi::DenseTensor> final_hidden_states =
794807 std::make_shared<phi::DenseTensor>();
@@ -800,10 +813,12 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
800813 *hidden_states_tensor,
801814 *gate_weights_tensor,
802815 gate_correction_tensor,
803- gate_up_weights_vec ,
804- down_weights_vec ,
816+ *gate_up_weights_tensor ,
817+ *down_weights_tensor ,
805818 paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
806- scales_vec,
819+ paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
820+ gate_up_weights_scales_tensor,
821+ down_weights_scales_tensor,
807822 final_hidden_states.get (),
808823 top_k,
809824 norm_topk_prob,
@@ -845,8 +860,8 @@ PD_BUILD_OP(fused_gate_moe)
845860 .Inputs({" hidden_states" ,
846861 " gate_weights" ,
847862 paddle::Optional (" gate_correction_bias" ),
848- paddle::Vec ( " gate_up_weights" ) ,
849- paddle::Vec ( " down_weights" ) })
863+ " gate_up_weights" ,
864+ " down_weights" })
850865 .Outputs({" final_hidden_states" })
851866 .Attrs({" top_k: int" ,
852867 " norm_topk_prob: bool" ,
@@ -869,12 +884,12 @@ PD_BUILD_OP(fused_gate_moe_fp8)
869884 .Inputs({" hidden_states" ,
870885 " gate_weights" ,
871886 paddle::Optional (" gate_correction_bias" ),
872- paddle::Vec ( " gate_up_weights" ) ,
873- paddle::Vec ( " down_weights" ) ,
887+ " gate_up_weights" ,
888+ " down_weights" ,
874889 paddle::Optional (" hidden_states_scales" ),
875890 paddle::Optional (paddle::Vec (" intermediate_hidden_states_scales" )),
876- paddle::Vec ( " gate_up_weights_scales" ) ,
877- paddle::Vec ( " down_weights_scales" ) })
891+ " gate_up_weights_scales" ,
892+ " down_weights_scales" })
878893 .Outputs({" final_hidden_states" })
879894 .Attrs({" top_k: int" ,
880895 " norm_topk_prob: bool" ,
@@ -896,10 +911,10 @@ PD_BUILD_OP(fused_gate_moe_blockwise_fp8)
896911 .Inputs({" hidden_states" ,
897912 " gate_weights" ,
898913 paddle::Optional (" gate_correction_bias" ),
899- paddle::Vec ( " gate_up_weights" ) ,
900- paddle::Vec ( " down_weights" ) ,
901- paddle::Vec ( " gate_up_weights_scales" ) ,
902- paddle::Vec ( " down_weights_scales" ) })
914+ " gate_up_weights" ,
915+ " down_weights" ,
916+ " gate_up_weights_scales" ,
917+ " down_weights_scales" })
903918 .Outputs({" final_hidden_states" })
904919 .Attrs({" top_k: int" ,
905920 " norm_topk_prob: bool" ,
0 commit comments