@@ -1608,25 +1608,28 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value
 
 
-@torch.jit.script
-def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor):
-    if seq_len > max_pos_embeddings:
-        return long_factor
-    return short_factor
+# @torch.jit.script
+def select_ext_factor(
+    seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
+):
+    return torch.where(
+        seq_len < max_pos_embeddings, short_factor, long_factor
+    )  # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings)
+
 
 def long_rope(self, x, position_ids, seq_len=None):
     seq_len = torch.max(position_ids) + 1
     original_max_position_embeddings = (
         self.original_max_position_embeddings
-        if hasattr(self, "original_max_positional_embeddings") else self.config.original_max_position_embeddings
+        if hasattr(self, "original_max_positional_embeddings")
+        else self.config.original_max_position_embeddings
     )
-    max_position_embeddings = self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
-    inv_freq = select_ext_factor(
-        seq_len,
-        torch.tensor(original_max_position_embeddings),
-        self.inv_freq,
-        self.long_inv_freq
+    max_position_embeddings = (
+        self.max_position_embeddings
+        if hasattr(self, "max_position_embeddings")
+        else self.config.max_position_embeddings
     )
+    inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)
 
     inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
     position_ids_expanded = position_ids[:, None, :].float()
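The rewritten `select_ext_factor` replaces data-dependent Python branching with `torch.where`, so the short/long factor choice stays inside the traced graph instead of being fixed at export time. A minimal standalone sketch of that idea (the factor values and `max_pos` below are made up for illustration):

```python
import torch

# Illustrative stand-ins for the short/long inverse-frequency buffers.
short_factor = torch.tensor([1.0, 0.5, 0.25])
long_factor = torch.tensor([2.0, 1.0, 0.5])
max_pos = torch.tensor(4096)


def select(seq_len: torch.Tensor) -> torch.Tensor:
    # Both branches are plain tensors, so tracing records a single select op
    # rather than specializing on whichever Python branch happened to run.
    return torch.where(seq_len < max_pos, short_factor, long_factor)


print(select(torch.tensor(1024)))  # short factors
print(select(torch.tensor(8192)))  # long factors
```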
@@ -1679,9 +1682,16 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
-
-        if hasattr(self._model.model, "rotary_emb") and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope":
-            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(self._model.config, torch.device("cpu"), seq_len=self._model.config.original_max_position_embeddings + 1)
+
+        if (
+            hasattr(self._model.model, "rotary_emb")
+            and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
+        ):
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config,
+                torch.device("cpu"),
+                seq_len=self._model.config.original_max_position_embeddings + 1,
+            )
             self._model.model.rotary_emb.long_inv_freq = long_inv_freq
             self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
             self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
@@ -1690,7 +1700,6 @@ def __enter__(self):
         ):
             self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
 
-
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_forward"):
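`__enter__` rebinds `rotary_emb.forward` to `long_rope` with `types.MethodType` and keeps the original under `_orig_forward` so `__exit__` can restore it. A rough sketch of that bind-and-restore pattern on a dummy object (all names here are illustrative, not the patcher's real attributes):

```python
import types


class Rotary:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # Stand-in for long_rope: bound onto the instance, shadowing the class method.
    return x * 10


rot = Rotary()
rot._orig_forward = rot.forward                       # save the original bound method
rot.forward = types.MethodType(patched_forward, rot)  # patch, as done in __enter__
assert rot.forward(2) == 20

rot.forward = rot._orig_forward                        # restore, as done in __exit__
del rot._orig_forward
assert rot.forward(2) == 3
```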