
Commit 2110d10

update
1 parent 323c590 commit 2110d10

File tree

1 file changed: +14 −6 lines changed


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 14 additions & 6 deletions
@@ -195,6 +195,8 @@ def forward_without_residual(self, inputs):
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         norm_output, self.shared_experts.w1, self.shared_experts.w2
                     )
+                    norm_output = None
+                    del norm_output
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
@@ -226,13 +228,19 @@ def forward(self, inputs):
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
+                    global norm_out
+                    # shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
+                    #     hidden_states,
+                    #     self.shared_experts.norm_weight,
+                    #     self.shared_experts.norm_eps,
+                    #     self.shared_experts.w1,
+                    #     self.shared_experts.w2,
+                    # )
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         norm_output, self.shared_experts.w1, self.shared_experts.w2
                     )
+                    norm_output = None
+                    del norm_output
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
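Note for readers skimming the diff: the change swaps the fused norm + FP8 MLP path (fp8_mlp_fwd_norm_rc) for a plain fp8_mlp_fwd call on a precomputed norm_output, then drops the reference so the large activation can be reclaimed immediately instead of at the end of the scope. The sketch below illustrates that reference-dropping pattern in isolation; fake_fp8_mlp_fwd, its SwiGLU body, and all tensor shapes are hypothetical stand-ins, not the real FP8 kernels in FP8LinearFunctionBase. (As in the commit, norm_output = None already releases the tensor; the following del merely unbinds the now-None name.)

import paddle

def fake_fp8_mlp_fwd(x, w1, w2):
    # Hypothetical stand-in for FP8LinearFunctionBase.fp8_mlp_fwd: a
    # SwiGLU-style MLP. The real kernel quantizes to FP8 first.
    gate_up = paddle.matmul(x, w1)                # [B, S, 2*I]
    gate, up = paddle.chunk(gate_up, 2, axis=-1)  # [B, S, I] each
    out = paddle.matmul(paddle.nn.functional.silu(gate) * up, w2)
    return None, None, out                        # mirrors the 3-tuple unpacking in the diff

with paddle.no_grad():
    norm_output = paddle.randn([2, 128, 256])     # post-norm activations (made-up shape)
    w1 = paddle.randn([256, 512])                 # fused gate + up projection
    w2 = paddle.randn([256, 256])                 # down projection
    _, _, shared_expert_output = fake_fp8_mlp_fwd(norm_output, w1, w2)
    # Drop the only reference so the activation's memory is freed now rather
    # than when the enclosing scope ends, the pattern this commit adds.
    norm_output = None
    del norm_output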
