@@ -203,6 +203,7 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
             condense_ratio=self.config.rope_condense_ratio,
             base=self.config.rope_base,
             extra_config=extra_config,
+            rope_local_base_freq=self.config.rope_local_base_freq,
         )

     def set_kv_cache(
@@ -567,6 +568,7 @@ def build_rope_cache(
     base: int = 10000,
     condense_ratio: int = 1,
     extra_config: Optional[dict] = None,
+    rope_local_base_freq: Optional[float] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Enhanced Transformer with Rotary Position Embedding.
@@ -620,6 +622,17 @@ def build_rope_cache(
     if idx_theta.shape[-1] > n_elem > 1:
         idx_theta = idx_theta[..., :n_elem]

+    # if rope_local_base_freq is given, have a separate rope value for local embedding
+    # For now, we use default RoPE for local embedding
+    if rope_local_base_freq is not None:
+        local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
+        local_idx_theta = torch.outer(seq_idx, local_theta)
+        local_idx_theta = local_idx_theta.repeat(1, 2)
+        if local_idx_theta.shape[-1] > n_elem > 1:
+            local_idx_theta = local_idx_theta[..., :n_elem]
+
+        idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)
+
     return torch.cos(idx_theta), torch.sin(idx_theta)

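Below is a minimal, standalone sketch of what the added branch computes, useful for seeing the resulting cache shapes. The function name dual_rope_cache_sketch and its simplified signature are illustrative only (not part of the library), and the extra_config handling and the n_elem truncation check from the real build_rope_cache are omitted for brevity.

import torch
from typing import Optional, Tuple

def dual_rope_cache_sketch(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    rope_local_base_freq: Optional[float] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Global RoPE angles: theta_i = base^(-2i / n_elem) for positions 0..seq_len-1.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    seq_idx = torch.arange(seq_len, device=device).float()
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)  # (seq_len, n_elem)

    if rope_local_base_freq is not None:
        # Same construction with a different base, intended for local-attention layers.
        local_theta = 1.0 / (
            rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)
        )
        local_idx_theta = torch.outer(seq_idx, local_theta).repeat(1, 2)
        # Stack global and local angles along a new trailing axis: (seq_len, n_elem, 2).
        idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)

    return torch.cos(idx_theta), torch.sin(idx_theta)

cos, sin = dual_rope_cache_sketch(seq_len=8, n_elem=16, rope_local_base_freq=10000.0)
print(cos.shape, sin.shape)  # torch.Size([8, 16, 2]) each: index 0 = global cache, 1 = local cache

With rope_local_base_freq set, the returned cos/sin gain a trailing dimension of size 2, so a model mixing global and local (e.g. sliding-window) attention layers would presumably select cos[..., 0]/sin[..., 0] or cos[..., 1]/sin[..., 1] per layer when applying RoPE.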