Commit 831d818

add recompute of postnorm in pp (#10849)
1 parent: c52eb6f

File tree: 3 files changed, +259 −378 lines

paddlenlp/transformers/deepseek_v2/modeling.py (3 additions, 1 deletion)
```diff
@@ -2040,7 +2040,9 @@ def self_attn_compute(self, hidden_states, **kwargs):
         hidden_states = residual + hidden_states

         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+
+        if not self.using_post_norm_recompute:
+            hidden_states = self.post_attention_layernorm(hidden_states)

         return hidden_states, residual

```
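The change above makes the attention stage skip `post_attention_layernorm` when post-norm recompute is enabled, so the *un-normalized* residual stream is handed to the next pipeline stage, which recomputes the norm itself rather than caching the normalized activation. The sketch below illustrates that pattern with plain NumPy; it is a hypothetical reconstruction, not the PaddleNLP implementation (`layer_norm`, `mlp_compute`, and the function signatures here are invented for illustration), and it checks that both paths yield the same output.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm, for illustration only.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def self_attn_compute(hidden_states, residual, using_post_norm_recompute):
    # Mirrors the diff: add the attention output to the residual stream,
    # then either normalize here (normal path) or defer normalization to
    # the consumer (recompute path) so the normalized tensor need not be
    # kept alive across the pipeline boundary.
    hidden_states = residual + hidden_states
    residual = hidden_states
    if not using_post_norm_recompute:
        hidden_states = layer_norm(hidden_states)
    return hidden_states, residual

def mlp_compute(hidden_states, residual, using_post_norm_recompute, w):
    # Hypothetical downstream stage: recompute the post-attention norm
    # locally when the producer skipped it.
    if using_post_norm_recompute:
        hidden_states = layer_norm(hidden_states)
    return residual + hidden_states @ w
```

Both settings of `using_post_norm_recompute` produce identical results; the flag only trades a cached activation for one extra LayerNorm in the downstream stage.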
