@@ -59,12 +59,7 @@
 )
 from paddlenlp.transformers.moe_layer import FusionMoeNode
 
-from ..fp8_utils import (
-    fp8_mlp_bwd,
-    fp8_mlp_bwd_norm_rc,
-    fp8_mlp_fwd,
-    fp8_mlp_fwd_norm_rc,
-)
+from ..fp8_utils import FP8LinearFunctionBase
 
 __all__ = [
     "DeepseekV2ForCausalLMPipe",
@@ -175,15 +170,17 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = fp8_mlp_fwd_norm_rc(
+                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
                         hidden_states,
                         self.shared_experts.norm_weight,
                         self.shared_experts.norm_eps,
                         self.shared_experts.w1,
                         self.shared_experts.w2,
                     )
                 else:
-                    shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        hidden_states, self.shared_experts.w1, self.shared_experts.w2
+                    )
                 final_hidden_states = final_hidden_states + shared_expert_output
 
         self.x = hidden_states
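
Note that the non-recompute branch now unpacks a 3-tuple (_, _, shared_expert_output) where the old free function returned the output directly; presumably the extra values are intermediates the FP8 path exposes for reuse elsewhere, and only the final shared-expert output is needed here.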
@@ -201,7 +198,7 @@ def backward(self, output_grad):
 
         assert not self.send_mtp_embed, "not support have mtp have yet"
         if self.using_post_norm_recompute:
-            dx = fp8_mlp_bwd_norm_rc(
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
                 do3,
                 self.x,
                 self.shared_experts.norm_weight,
@@ -210,7 +207,7 @@ def backward(self, output_grad):
                 self.shared_experts.w2,
             )
         else:
-            dx = fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
 
         self.x = None
 
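
Taken together, forward and backward follow a manual-recompute pattern: forward runs under paddle.no_grad() and stashes its input in self.x, and backward replays the FP8 MLP from that stashed input to produce dx, then releases it. A minimal sketch of that pattern, assuming the hypothetical signatures above (this is not the actual FusionMoeNode implementation):

    import paddle

    class SharedExpertNodeSketch:
        """Illustrative only: mirrors the stash-input / recompute-backward pattern."""

        def __init__(self, w1, w2):
            self.w1, self.w2 = w1, w2
            self.x = None

        def forward(self, hidden_states):
            with paddle.no_grad():
                # Keep only the final output; gradients are produced manually later.
                _, _, out = FP8LinearFunctionBase.fp8_mlp_fwd(hidden_states, self.w1, self.w2)
            self.x = hidden_states  # saved for the manual backward
            return out

        def backward(self, do3):
            # Gradient w.r.t. the stashed input, recomputed through the FP8 MLP.
            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.w1, self.w2)
            self.x = None  # release the saved activation
            return dx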