Commit cd2e016

[INTEL_HPU] MoE weights and scales from list to tensor (PaddlePaddle#2219)

* list to tensor for moe and channel wise
* rebase update

1 parent f461d0e commit cd2e016

File tree: 6 files changed (+259, -165)
backends/intel_hpu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ option(WITH_MKL "compile with mkl support" ON)
 option(WITH_ARM "compile with arm support" OFF)
 
 set(PLUGIN_NAME "paddle-intel-hpu")
-set(PLUGIN_VERSION "0.0.1")
+set(PLUGIN_VERSION "0.0.2")
 
 include(paddle)
 include(generic)
```
backends/intel_hpu/custom_ops/llama_infer/fused_gate_moe.cc

Lines changed: 104 additions & 89 deletions
```diff
@@ -413,10 +413,13 @@ void FusedGateMoeKernel(
     const phi::DenseTensor& hidden_states,
     const phi::DenseTensor& gate_weights,
     const paddle::optional<phi::DenseTensor>& gate_correction_bias,
-    const std::vector<phi::DenseTensor>& gate_up_weights,
-    const std::vector<phi::DenseTensor>& down_weights,
+    const phi::DenseTensor& gate_up_weights,
+    const phi::DenseTensor& down_weights,
     const paddle::optional<phi::DenseTensor>& hidden_states_scales,
-    const paddle::optional<std::vector<phi::DenseTensor>>& scales,
+    const paddle::optional<std::vector<phi::DenseTensor>>&
+        intermediate_hidden_states_scales,
+    const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
+    const paddle::optional<phi::DenseTensor>& down_weights_scales,
     phi::DenseTensor* final_hidden_states,
     const int top_k,
     const bool norm_topk_prob,
@@ -428,13 +431,18 @@ void FusedGateMoeKernel(
     const bool dynamic_scale,
     const int block_size,
     const int chunk_size) {
+  std::vector<int64_t> gate_up_weights_dims =
+      phi::vectorize<int64_t>(gate_up_weights.dims());
+  std::vector<int64_t> down_weights_dims =
+      phi::vectorize<int64_t>(down_weights.dims());
   FusedGateMoeParams params;
   memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedGateMoeParams));
   params.topk = top_k;
   params.norm_topk_prob = norm_topk_prob;
   params.permuted_weights = permuted_weights;
-  params.fused_gemm = (gate_up_weights.size() == down_weights.size());
-  params.num_experts = down_weights.size();
+  params.fused_gemm = (gate_up_weights_dims[2] == down_weights_dims[1] * 2);
+  params.measurement_mode = measurement_mode;
+  params.num_experts = gate_up_weights_dims[0];
   params.experts_min = experts_min;
   params.experts_max = experts_max;
   params.hidden_states_static_quant = false;
@@ -456,17 +464,20 @@ void FusedGateMoeKernel(
     ct.Add(hidden_states_scales.get());
     params.hidden_states_static_quant = true;
   }
-  for (const auto& t : gate_up_weights) {
-    ct.Add(t);
-  }
-  for (const auto& t : down_weights) {
-    ct.Add(t);
-  }
-  if (scales) {
-    for (const auto& t : scales.get()) {
+  ct.AddN(gate_up_weights);
+  ct.AddN(down_weights);
+
+  if (intermediate_hidden_states_scales) {
+    for (const auto& t : intermediate_hidden_states_scales.get()) {
       ct.Add(t);
     }
   }
+  if (gate_up_weights_scales) {
+    ct.AddN(gate_up_weights_scales.get());
+  }
+  if (down_weights_scales) {
+    ct.AddN(down_weights_scales.get());
+  }
 
   ct.Add(*final_hidden_states, false);
```
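With a stacked tensor, the kernel can no longer infer the expert count from a list length, so it derives `num_experts` and `fused_gemm` from the weight dimensions instead. A sketch of that inference in Python, assuming the layouts `[E, H, 2*I]` for `gate_up_weights` and `[E, I, H]` for `down_weights` (inferred from the index arithmetic above, not stated explicitly in the source):

```python
def infer_moe_params(gate_up_dims, down_dims):
    """Mirror of the shape checks in FusedGateMoeKernel (assumed layouts)."""
    num_experts = gate_up_dims[0]  # experts are the leading dimension
    # gate and up projections are fused into one GEMM when gate_up's last
    # dimension is twice the intermediate size consumed by the down projection
    fused_gemm = gate_up_dims[2] == down_dims[1] * 2
    return num_experts, fused_gemm

print(infer_moe_params([8, 4096, 22016], [8, 11008, 4096]))  # (8, True)
```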

```diff
@@ -500,10 +511,13 @@ void CallFusedGateMoeKernel(
     const phi::DenseTensor& hidden_states,
     const phi::DenseTensor& gate_weights,
     const paddle::optional<phi::DenseTensor>& gate_correction_bias,
-    const std::vector<phi::DenseTensor>& gate_up_weights,
-    const std::vector<phi::DenseTensor>& down_weights,
+    const phi::DenseTensor& gate_up_weights,
+    const phi::DenseTensor& down_weights,
     const paddle::optional<phi::DenseTensor>& hidden_states_scales,
-    const paddle::optional<std::vector<phi::DenseTensor>>& scales,
+    const paddle::optional<std::vector<phi::DenseTensor>>&
+        intermediate_hidden_states_scales,
+    const paddle::optional<phi::DenseTensor>& gate_up_weights_scales,
+    const paddle::optional<phi::DenseTensor>& down_weights_scales,
     phi::DenseTensor* final_hidden_states,
     const int top_k,
     const bool norm_topk_prob,
@@ -528,7 +542,9 @@ void CallFusedGateMoeKernel(
       gate_up_weights,
       down_weights,
       hidden_states_scales,
-      scales,
+      intermediate_hidden_states_scales,
+      gate_up_weights_scales,
+      down_weights_scales,
       final_hidden_states,
       top_k,
       norm_topk_prob,
@@ -550,7 +566,9 @@ void CallFusedGateMoeKernel(
       gate_up_weights,
       down_weights,
       hidden_states_scales,
-      scales,
+      intermediate_hidden_states_scales,
+      gate_up_weights_scales,
+      down_weights_scales,
       final_hidden_states,
       top_k,
       norm_topk_prob,
```
```diff
@@ -572,8 +590,8 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
     const paddle::Tensor& hidden_states,
     const paddle::Tensor& gate_weights,
     const paddle::optional<paddle::Tensor>& gate_correction_bias,
-    const std::vector<paddle::Tensor>& gate_up_weights,
-    const std::vector<paddle::Tensor>& down_weights,
+    const paddle::Tensor& gate_up_weights,
+    const paddle::Tensor& down_weights,
     const int top_k,
     const bool norm_topk_prob,
     const bool permuted_weights,
@@ -598,16 +616,10 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
         paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
   }
 
-  std::vector<phi::DenseTensor> gate_up_weights_vec;
-  for (const auto& t : gate_up_weights) {
-    gate_up_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
-  std::vector<phi::DenseTensor> down_weights_vec;
-  for (const auto& t : down_weights) {
-    down_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
+  auto gate_up_weights_tensor =
+      static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
+  auto down_weights_tensor =
+      static_cast<const phi::DenseTensor*>(down_weights.impl().get());
 
   std::shared_ptr<phi::DenseTensor> final_hidden_states =
       std::make_shared<phi::DenseTensor>();
@@ -619,10 +631,12 @@ std::vector<paddle::Tensor> FusedGateMoeForward(
       *hidden_states_tensor,
       *gate_weights_tensor,
       gate_correction_tensor,
-      gate_up_weights_vec,
-      down_weights_vec,
+      *gate_up_weights_tensor,
+      *down_weights_tensor,
       paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
-      paddle::optional<std::vector<phi::DenseTensor>>(), /* scales */
+      paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
+      paddle::optional<phi::DenseTensor>(), /* gate_up_weights_scales */
+      paddle::optional<phi::DenseTensor>(), /* down_weights_scales */
       final_hidden_states.get(),
       top_k,
       norm_topk_prob,
```
```diff
@@ -643,13 +657,13 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
     const paddle::Tensor& hidden_states,
     const paddle::Tensor& gate_weights,
     const paddle::optional<paddle::Tensor>& gate_correction_bias,
-    const std::vector<paddle::Tensor>& gate_up_weights,
-    const std::vector<paddle::Tensor>& down_weights,
+    const paddle::Tensor& gate_up_weights,
+    const paddle::Tensor& down_weights,
     const paddle::optional<paddle::Tensor>& hidden_states_scales,
     const paddle::optional<std::vector<paddle::Tensor>>&
         intermediate_hidden_states_scales,
-    const std::vector<paddle::Tensor>& gate_up_weights_scales,
-    const std::vector<paddle::Tensor>& down_weights_scales,
+    const paddle::Tensor& gate_up_weights_scales,
+    const paddle::Tensor& down_weights_scales,
     const int top_k,
     const bool norm_topk_prob,
     const bool permuted_weights,
@@ -674,16 +688,10 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
         paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
   }
 
-  std::vector<phi::DenseTensor> gate_up_weights_vec;
-  for (const auto& t : gate_up_weights) {
-    gate_up_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
-  std::vector<phi::DenseTensor> down_weights_vec;
-  for (const auto& t : down_weights) {
-    down_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
+  auto gate_up_weights_tensor =
+      static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
+  auto down_weights_tensor =
+      static_cast<const phi::DenseTensor*>(down_weights.impl().get());
 
   auto hidden_states_scales_tensor = paddle::optional<phi::DenseTensor>();
   if (hidden_states_scales) {
@@ -702,12 +710,17 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
         *static_cast<const phi::DenseTensor*>(t.impl().get()));
     }
   }
-  for (const auto& t : gate_up_weights_scales) {
-    scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
-  for (const auto& t : down_weights_scales) {
-    scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
+  auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
+  auto gate_up_weights_scales_dt =
+      static_cast<const phi::DenseTensor*>(gate_up_weights_scales.impl().get());
+  gate_up_weights_scales_tensor =
+      paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
+
+  auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
+  auto down_weights_scales_dt =
+      static_cast<const phi::DenseTensor*>(down_weights_scales.impl().get());
+  down_weights_scales_tensor =
+      paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
 
   std::shared_ptr<phi::DenseTensor> final_hidden_states =
       std::make_shared<phi::DenseTensor>();
@@ -719,10 +732,12 @@ std::vector<paddle::Tensor> FusedGateMoeFP8Forward(
       *hidden_states_tensor,
       *gate_weights_tensor,
       gate_correction_tensor,
-      gate_up_weights_vec,
-      down_weights_vec,
+      *gate_up_weights_tensor,
+      *down_weights_tensor,
       hidden_states_scales_tensor,
       scales_vec,
+      gate_up_weights_scales_tensor,
+      down_weights_scales_tensor,
       final_hidden_states.get(),
       top_k,
       norm_topk_prob,
```
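Note the asymmetry this leaves in the FP8 path: the per-weight-group scales are now single tensors wrapped into `paddle::optional`, while `intermediate_hidden_states_scales` remains an optional list. A shape sketch, assuming the channel-wise `[E, N]` scale layout produced by the converter shown further below (sizes hypothetical):

```python
import paddle

E, H, I = 8, 4096, 11008  # hypothetical expert count / hidden / intermediate

# Weight scales: one stacked tensor per weight group (was a paddle::Vec input).
gate_up_weights_scales = paddle.rand([E, 2 * I]).cast("bfloat16")  # [E, N]
down_weights_scales = paddle.rand([E, H]).cast("bfloat16")         # [E, N]

# Intermediate activation scales: still an optional list of per-expert tensors.
intermediate_hidden_states_scales = [paddle.rand([1]) for _ in range(E)]
```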
```diff
@@ -742,10 +757,10 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
     const paddle::Tensor& hidden_states,
     const paddle::Tensor& gate_weights,
     const paddle::optional<paddle::Tensor>& gate_correction_bias,
-    const std::vector<paddle::Tensor>& gate_up_weights,
-    const std::vector<paddle::Tensor>& down_weights,
-    const std::vector<paddle::Tensor>& gate_up_weights_scales,
-    const std::vector<paddle::Tensor>& down_weights_scales,
+    const paddle::Tensor& gate_up_weights,
+    const paddle::Tensor& down_weights,
+    const paddle::Tensor& gate_up_weights_scales,
+    const paddle::Tensor& down_weights_scales,
     const int top_k,
     const bool norm_topk_prob,
     const bool permuted_weights,
@@ -771,24 +786,22 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
         paddle::optional<phi::DenseTensor>(*gate_correction_bias_dt);
   }
 
-  std::vector<phi::DenseTensor> gate_up_weights_vec;
-  for (const auto& t : gate_up_weights) {
-    gate_up_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
-  std::vector<phi::DenseTensor> down_weights_vec;
-  for (const auto& t : down_weights) {
-    down_weights_vec.push_back(
-        *static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
+  auto gate_up_weights_tensor =
+      static_cast<const phi::DenseTensor*>(gate_up_weights.impl().get());
+  auto down_weights_tensor =
+      static_cast<const phi::DenseTensor*>(down_weights.impl().get());
 
-  std::vector<phi::DenseTensor> scales_vec;
-  for (const auto& t : gate_up_weights_scales) {
-    scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
-  for (const auto& t : down_weights_scales) {
-    scales_vec.push_back(*static_cast<const phi::DenseTensor*>(t.impl().get()));
-  }
+  auto gate_up_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
+  auto gate_up_weights_scales_dt =
+      static_cast<const phi::DenseTensor*>(gate_up_weights_scales.impl().get());
+  gate_up_weights_scales_tensor =
+      paddle::optional<phi::DenseTensor>(*gate_up_weights_scales_dt);
+
+  auto down_weights_scales_tensor = paddle::optional<phi::DenseTensor>();
+  auto down_weights_scales_dt =
+      static_cast<const phi::DenseTensor*>(down_weights_scales.impl().get());
+  down_weights_scales_tensor =
+      paddle::optional<phi::DenseTensor>(*down_weights_scales_dt);
 
   std::shared_ptr<phi::DenseTensor> final_hidden_states =
       std::make_shared<phi::DenseTensor>();
@@ -800,10 +813,12 @@ std::vector<paddle::Tensor> FusedGateMoeBlockWiseFP8Forward(
       *hidden_states_tensor,
       *gate_weights_tensor,
       gate_correction_tensor,
-      gate_up_weights_vec,
-      down_weights_vec,
+      *gate_up_weights_tensor,
+      *down_weights_tensor,
       paddle::optional<phi::DenseTensor>(), /* hidden_states_scale */
-      scales_vec,
+      paddle::optional<std::vector<phi::DenseTensor>>(), /* intermediate */
+      gate_up_weights_scales_tensor,
+      down_weights_scales_tensor,
       final_hidden_states.get(),
       top_k,
       norm_topk_prob,
```
```diff
@@ -845,8 +860,8 @@ PD_BUILD_OP(fused_gate_moe)
     .Inputs({"hidden_states",
              "gate_weights",
              paddle::Optional("gate_correction_bias"),
-             paddle::Vec("gate_up_weights"),
-             paddle::Vec("down_weights")})
+             "gate_up_weights",
+             "down_weights"})
     .Outputs({"final_hidden_states"})
     .Attrs({"top_k: int",
             "norm_topk_prob: bool",
@@ -869,12 +884,12 @@ PD_BUILD_OP(fused_gate_moe_fp8)
     .Inputs({"hidden_states",
              "gate_weights",
              paddle::Optional("gate_correction_bias"),
-             paddle::Vec("gate_up_weights"),
-             paddle::Vec("down_weights"),
+             "gate_up_weights",
+             "down_weights",
              paddle::Optional("hidden_states_scales"),
              paddle::Optional(paddle::Vec("intermediate_hidden_states_scales")),
-             paddle::Vec("gate_up_weights_scales"),
-             paddle::Vec("down_weights_scales")})
+             "gate_up_weights_scales",
+             "down_weights_scales"})
     .Outputs({"final_hidden_states"})
     .Attrs({"top_k: int",
             "norm_topk_prob: bool",
@@ -896,10 +911,10 @@ PD_BUILD_OP(fused_gate_moe_blockwise_fp8)
     .Inputs({"hidden_states",
              "gate_weights",
              paddle::Optional("gate_correction_bias"),
-             paddle::Vec("gate_up_weights"),
-             paddle::Vec("down_weights"),
-             paddle::Vec("gate_up_weights_scales"),
-             paddle::Vec("down_weights_scales")})
+             "gate_up_weights",
+             "down_weights",
+             "gate_up_weights_scales",
+             "down_weights_scales"})
     .Outputs({"final_hidden_states"})
     .Attrs({"top_k: int",
             "norm_topk_prob: bool",
```

backends/intel_hpu/custom_ops/python/paddlenlp_ops/Model_convert.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -87,7 +87,7 @@ def save_tail_tensors_and_index(
 def tensorwise_quant_to_fp8(tensor):
     x_abs = paddle.abs(tensor).astype(paddle.float32)
     x_amax = paddle.amax(x_abs)
-    x_amax = paddle.clip(x_amax, min=1e-4)
+    x_amax = paddle.clip(x_amax, min=1e-8)
     scale = x_amax / 240.0
     x_scaled = (tensor.cast("float32") / scale).cast("float8_e4m3fn").clone()
```
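A plausible reading of the lower clip floor (the commit itself gives no rationale): the floor only keeps the scale away from zero, and `1e-4` forced any tensor with a smaller amax onto a sliver of the FP8 grid. A worked example using the e4m3fn maximum of 240 from the code above:

```python
amax = 1e-6  # a tensor whose largest magnitude is well below the old floor

old_scale = max(amax, 1e-4) / 240.0  # ~4.17e-7
new_scale = max(amax, 1e-8) / 240.0  # ~4.17e-9

print(amax / old_scale)  # ~2.4  -> uses almost none of the +/-240 range
print(amax / new_scale)  # ~240  -> uses the full e4m3fn range
```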

```diff
@@ -96,6 +96,19 @@ def tensorwise_quant_to_fp8(tensor):
     )
 
 
+def channelwise_quant_to_fp8(tensor):
+    # Channel-wise quantization along the last dimension (N)
+    x_abs = paddle.abs(tensor).astype(paddle.float32)
+    x_amax = paddle.amax(x_abs, axis=0)  # shape: [N]
+    x_amax = paddle.clip(x_amax, min=1e-8)
+    scale = x_amax / 240.0  # shape: [N]
+    x_scaled = (
+        (tensor.cast("float32") / scale.cast("float32")).cast("float8_e4m3fn").clone()
+    )
+
+    return paddle.view(x_scaled, "int8").clone(), scale.cast("bfloat16").clone()
+
+
 def process_safetensors_file(
     tensors_dict,
     src_path,
```
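A quick contrast between the two quantizers on a toy weight whose columns differ wildly in magnitude (toy values; the converter above reduces over axis 0, so each output channel gets its own scale):

```python
import paddle

# Toy 2x2 weight: one tiny column, one large column (float32 for clarity).
w = paddle.to_tensor([[0.001, 100.0],
                      [0.002, 200.0]])

# Tensor-wise: a single scale, dominated by the large column.
scale_all = paddle.clip(paddle.amax(paddle.abs(w)), min=1e-8) / 240.0
# -> ~0.833; the tiny column maps to at most 0.002 / 0.833 ~ 0.0024 on the
#    FP8 grid, wasting nearly all of its resolution.

# Channel-wise: one scale per column, so the tiny column keeps precision.
scale_col = paddle.clip(paddle.amax(paddle.abs(w), axis=0), min=1e-8) / 240.0
# -> [~8.3e-06, ~0.833]; each column now spans the full FP8 range.
```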
```diff
@@ -118,7 +131,10 @@ def process_safetensors_file(
             continue
         else:
             tensor = paddle.Tensor(tensor, zero_copy=True)
-            quant_tensor, scale = tensorwise_quant_to_fp8(tensor)
+            if ".experts." in key:  # except for shared_experts
+                quant_tensor, scale = channelwise_quant_to_fp8(tensor)
+            else:
+                quant_tensor, scale = tensorwise_quant_to_fp8(tensor)
 
         t_size = tensor_size(quant_tensor) + tensor_size(scale)
         if current_size + t_size > max_size_bytes and tensors_dict:
```
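The substring test routes per-expert weights to the channel-wise path while leaving shared-expert weights tensor-wise, since `shared_experts` has an underscore rather than a dot before `experts`. With hypothetical DeepSeek-style key names:

```python
keys = [
    "model.layers.3.mlp.experts.7.down_proj.weight",       # routed expert
    "model.layers.3.mlp.shared_experts.down_proj.weight",  # shared expert
]
for key in keys:
    print(key, ".experts." in key)
# ...experts.7.down_proj.weight True        -> channel-wise
# ...shared_experts.down_proj.weight False  -> tensor-wise
```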
