@@ -195,6 +195,8 @@ def forward_without_residual(self, inputs):
                 _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                     norm_output, self.shared_experts.w1, self.shared_experts.w2
                 )
+                norm_output = None
+                del norm_output
             else:
                 _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                     hidden_states, self.shared_experts.w1, self.shared_experts.w2
@@ -226,13 +228,19 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
+                    global norm_out
+                    # shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
+                    #     hidden_states,
+                    #     self.shared_experts.norm_weight,
+                    #     self.shared_experts.norm_eps,
+                    #     self.shared_experts.w1,
+                    #     self.shared_experts.w2,
+                    # )
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        norm_output, self.shared_experts.w1, self.shared_experts.w2
                     )
+                    norm_output = None
+                    del norm_output
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
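
For context, a minimal, hypothetical sketch of the pattern these hunks introduce: run the shared-expert MLP on the precomputed norm_output under paddle.no_grad(), then drop the reference immediately so the normalized activation can be freed earlier instead of staying alive for the rest of forward(). fp8_mlp_fwd_stub, the tensor shapes, and the random weights below are illustrative stand-ins, not the repository's FP8LinearFunctionBase.fp8_mlp_fwd.

import paddle

def fp8_mlp_fwd_stub(x, w1, w2):
    # Illustrative stand-in for FP8LinearFunctionBase.fp8_mlp_fwd: a plain
    # two-matmul MLP that, like the helper called in the diff, returns the
    # MLP output as its third value (the first two returns are ignored here).
    h = paddle.nn.functional.silu(paddle.matmul(x, w1))
    return None, None, paddle.matmul(h, w2)

hidden_states = paddle.randn([4, 128, 256])
w1 = paddle.randn([256, 512])
w2 = paddle.randn([512, 256])

with paddle.no_grad():
    # In the real code norm_output comes from the post-norm recompute path;
    # here it is just a layer norm over random activations.
    norm_output = paddle.nn.functional.layer_norm(hidden_states, hidden_states.shape[-1])
    _, _, shared_expert_output = fp8_mlp_fwd_stub(norm_output, w1, w2)
    # Drop the local reference right away so the normalized activation can be
    # reclaimed earlier (assuming nothing else still holds it), mirroring
    # `norm_output = None; del norm_output` in the diff.
    norm_output = None
    del norm_output

Rebinding norm_output to None is what releases the tensor reference; the del that follows only removes the now-unneeded name, so the two-step form matches the diff but the first statement does the work.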