@@ -1608,15 +1608,51 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value


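+# Scripted (rather than traced) so the data-dependent branch on seq_len below is preserved in the exported graph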
+@torch.jit.script
+def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor):
+    if seq_len > max_pos_embeddings:
+        return long_factor
+    return short_factor
+
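+# Replacement forward for the rotary embedding of "longrope" models: picks the long inverse-frequency
+# factors once the sequence exceeds the original training context length, the short ones otherwise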
+def long_rope(self, x, position_ids, seq_len=None):
+    seq_len = torch.max(position_ids) + 1
+    original_max_position_embeddings = (
+        self.original_max_position_embeddings
+        if hasattr(self, "original_max_position_embeddings")
+        else self.config.original_max_position_embeddings
+    )
+    max_position_embeddings = (
+        self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
+    )
+    inv_freq = select_ext_factor(
+        seq_len,
+        torch.tensor(original_max_position_embeddings),
+        self.inv_freq,
+        self.long_inv_freq,
+    )
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+
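+    # Attention scaling for extended context: when the configured context exceeds the original
+    # training length, cos/sin are amplified by sqrt(1 + log(scale) / log(original length))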
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos, sin
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()

         # currently, long RoPE cannot be traced for long context support; disable it to avoid potential accuracy issues
-        if self._model.config.max_position_embeddings != getattr(
-            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
-        ):
-            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

         if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
@@ -1643,6 +1679,17 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
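+        # For long RoPE models, precompute the long-context inverse frequencies and patch the rotary
+        # embedding forward so both the short and long factors are available in the exported model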
+        if hasattr(self._model.model, "rotary_emb") and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope":
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config, torch.device("cpu"), seq_len=self._model.config.original_max_position_embeddings + 1
+            )
+            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
+            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
+            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
+        elif self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+

     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -1653,6 +1700,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
+            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward


 # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756