@@ -79,25 +79,27 @@ def group_gemm(
 def iluvatar_moe_expert_ffn(
     permute_input: paddle.Tensor,
     tokens_expert_prefix_sum: paddle.Tensor,
-    ffn1_weight: paddle.Tensor,
-    ffn2_weight: paddle.Tensor,
-    ffn1_bias: Optional[paddle.Tensor],
-    ffn1_scale: Optional[paddle.Tensor],
-    ffn2_scale: Optional[paddle.Tensor],
-    ffn2_in_scale: Optional[paddle.Tensor],
+    up_gate_proj_weight: paddle.Tensor,
+    down_proj_weight: paddle.Tensor,
+    up_gate_proj_bias: Optional[paddle.Tensor],
+    up_gate_proj_scale: Optional[paddle.Tensor],
+    down_proj_scale: Optional[paddle.Tensor],
+    down_proj_in_scale: Optional[paddle.Tensor],
     expert_idx_per_token: Optional[paddle.Tensor],
     quant_method: str,
     used_in_ep_low_latency: bool,
 ):
-    assert ffn1_bias is None
-    assert ffn1_scale is not None
-    assert ffn2_scale is not None
-    assert ffn2_in_scale is None
+    assert up_gate_proj_bias is None
+    assert up_gate_proj_scale is not None
+    assert down_proj_scale is not None
+    assert down_proj_in_scale is None
     assert expert_idx_per_token is None
     assert quant_method in ("weight_only_int8")
     assert not used_in_ep_low_latency
     tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
-    ffn1_output = w8a16_group_gemm(permute_input, ffn1_weight, ffn1_scale, tokens_expert_prefix_sum_cpu, -1)
+    ffn1_output = w8a16_group_gemm(
+        permute_input, up_gate_proj_weight, up_gate_proj_scale, tokens_expert_prefix_sum_cpu, -1
+    )
     act_out = swiglu(ffn1_output)
-    output = w8a16_group_gemm(act_out, ffn2_weight, ffn2_scale, tokens_expert_prefix_sum_cpu, -1)
+    output = w8a16_group_gemm(act_out, down_proj_weight, down_proj_scale, tokens_expert_prefix_sum_cpu, -1)
     return output
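For reference, below is a minimal NumPy sketch of the flow the renamed function implements: a weight-only int8 grouped GEMM over the gate/up projection, a SwiGLU activation, then a second grouped GEMM for the down projection. The helper names (w8a16_group_gemm_ref, swiglu_ref), the weight layouts, the per-output-channel dequantization, and the gate/up split order are illustrative assumptions, not the actual Iluvatar w8a16_group_gemm/swiglu kernels.

import numpy as np

def swiglu_ref(x):
    # SwiGLU on the concatenated [gate, up] projection: silu(gate) * up
    # (assumed split order for illustration).
    gate, up = np.split(x, 2, axis=-1)
    return gate / (1.0 + np.exp(-gate)) * up

def w8a16_group_gemm_ref(x, qweight, scale, tokens_expert_prefix_sum):
    # x: [total_tokens, k] activations already permuted so each expert's tokens
    #    are contiguous; qweight: [num_experts, k, n] int8; scale: [num_experts, n].
    # tokens_expert_prefix_sum: inclusive prefix sum of tokens per expert.
    out = np.empty((x.shape[0], qweight.shape[-1]), dtype=x.dtype)
    start = 0
    for e, end in enumerate(tokens_expert_prefix_sum):
        # Dequantize this expert's weight per output channel, then GEMM its tokens.
        w = qweight[e].astype(x.dtype) * scale[e]
        out[start:end] = x[start:end] @ w
        start = end
    return out

# Toy shapes (assumptions): 2 experts, hidden size 8, intermediate size 16.
num_experts, k, n = 2, 8, 16
prefix = np.cumsum(np.array([3, 5]))
x = np.random.randn(prefix[-1], k).astype(np.float32)
up_gate_w = np.random.randint(-127, 128, (num_experts, k, 2 * n), dtype=np.int8)
up_gate_s = np.full((num_experts, 2 * n), 0.01, dtype=np.float32)
down_w = np.random.randint(-127, 128, (num_experts, n, k), dtype=np.int8)
down_s = np.full((num_experts, k), 0.01, dtype=np.float32)

h = w8a16_group_gemm_ref(x, up_gate_w, up_gate_s, prefix)
out = w8a16_group_gemm_ref(swiglu_ref(h), down_w, down_s, prefix)
print(out.shape)  # (8, 8): one hidden-size row per permuted token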