@@ -601,34 +601,35 @@ def __init__(
601
601
super ().__init__ (dim , max_position_embeddings , base )
602
602
603
603
def _set_cos_sin_cache(self, seq_len):
    """Rebuild the cached cos/sin rotary tables for ``seq_len`` positions.

    Stores ``seq_len`` in ``self.max_seq_len_cached``, recomputes
    ``self.inv_freq`` as a mask-weighted blend of the unscaled inverse
    frequencies and the ``self.scaling_factor``-scaled ones, and writes the
    resulting tables to ``self.cos_cached`` and ``self.sin_cached``.

    Args:
        seq_len: Number of positions; used as the length of the position
            index vector passed to ``paddle.outer``.
    """
    # The whole computation runs with AMP auto-cast switched off; every
    # tensor created below is explicitly requested as float32.
    with paddle.amp.auto_cast(False):
        self.max_seq_len_cached = seq_len
        rotary_dim = self.dim

        # base ** (k / dim) for k = 0, 2, 4, ...; shared by both frequency sets.
        exponents = paddle.arange(0, rotary_dim, 2, dtype=paddle.float32) / rotary_dim
        base_powers = self.base ** exponents
        unscaled_freq = 1.0 / base_powers
        scaled_freq = 1.0 / (self.scaling_factor * base_powers)

        correction_low, correction_high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            rotary_dim,
            self.base,
            self.original_max_position_embeddings,
        )
        ramp = yarn_linear_ramp_mask(correction_low, correction_high, rotary_dim // 2)
        keep_unscaled = 1.0 - ramp
        # Where the mask is 1 the unscaled frequency is used, where it is 0
        # the scaled one; values in between mix the two linearly.
        self.inv_freq = scaled_freq * (1 - keep_unscaled) + unscaled_freq * keep_unscaled

        positions = paddle.arange(seq_len, dtype=paddle.float32)
        inv_freq_fp32 = paddle.cast(self.inv_freq, dtype="float32")
        angles = paddle.outer(positions, inv_freq_fp32)

        # Ratio of the two yarn_get_mscale values, applied to both tables.
        mscale_numerator = yarn_get_mscale(self.scaling_factor, self.mscale)
        mscale_denominator = yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        table_scale = float(mscale_numerator / mscale_denominator)

        # Duplicate the angles along the last axis before taking cos/sin.
        doubled_angles = paddle.concat((angles, angles), axis=-1)
        self.cos_cached = doubled_angles.cos() * table_scale
        self.sin_cached = doubled_angles.sin() * table_scale
632
633
633
634
634
635
def rotate_half (x ):
0 commit comments