Skip to content

Commit fb22a9f

Browse files
authored
yarn rope use fp32 (#10973)
1 parent 0efb202 commit fb22a9f

File tree

1 file changed: +25 -24 lines changed

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -601,34 +601,35 @@ def __init__(
601 601
super().__init__(dim, max_position_embeddings, base)
602 602

603 603
def _set_cos_sin_cache(self, seq_len):
604-
self.max_seq_len_cached = seq_len
605-
dim = self.dim
606-
607-
freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
608-
freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
609-
610-
low, high = yarn_find_correction_range(
611-
self.beta_fast,
612-
self.beta_slow,
613-
dim,
614-
self.base,
615-
self.original_max_position_embeddings,
616-
)
617-
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
618-
self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
604+
with paddle.amp.auto_cast(False):
605+
self.max_seq_len_cached = seq_len
606+
dim = self.dim
607+
608+
freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
609+
freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim))
610+
611+
low, high = yarn_find_correction_range(
612+
self.beta_fast,
613+
self.beta_slow,
614+
dim,
615+
self.base,
616+
self.original_max_position_embeddings,
617+
)
618+
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
619+
self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
619 620

620-
t = paddle.arange(seq_len, dtype=paddle.float32)
621+
t = paddle.arange(seq_len, dtype=paddle.float32)
621 622

622-
freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32"))
623+
freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32"))
623 624

624-
_mscale = float(
625-
yarn_get_mscale(self.scaling_factor, self.mscale)
626-
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
627-
)
625+
_mscale = float(
626+
yarn_get_mscale(self.scaling_factor, self.mscale)
627+
/ yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
628+
)
628 629

629-
emb = paddle.concat((freqs, freqs), axis=-1)
630-
self.cos_cached = emb.cos() * _mscale
631-
self.sin_cached = emb.sin() * _mscale
630+
emb = paddle.concat((freqs, freqs), axis=-1)
631+
self.cos_cached = emb.cos() * _mscale
632+
self.sin_cached = emb.sin() * _mscale
632 633

633 634

634 635
def rotate_half(x):

0 commit comments

Comments (0)