Commit 45d7ca9

feat: add local base freq for rope (#1993)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2c1ec4c commit 45d7ca9

File tree: 2 files changed, +16 -0 lines changed

litgpt/config.py

Lines changed: 3 additions & 0 deletions

@@ -81,6 +81,9 @@ class Config:
     scale_embeddings: bool = False
     lm_head_bias: bool = False
     final_logit_softcapping: Optional[float] = None
+    # The base period of the RoPE embeddings for local attention.
+    # If not provided, rope_base will be used for both local and global attention.
+    rope_local_base_freq: Optional[float] = None

     def __post_init__(self):
         if not self.name:
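
The new field defaults to None, so existing configs are unaffected. As a usage illustration, here is a minimal, hedged sketch of constructing a Config that rotates global-attention layers with a large base and local (sliding-window) layers with a smaller one; the concrete values, and the assumption that all other Config fields can stay at their defaults, are illustrative only and not part of this commit.

from litgpt.config import Config

# Hypothetical example values (not from this commit): global layers use a
# 1_000_000 base, local sliding-window layers use 10_000.
cfg = Config(
    rope_base=1_000_000,
    rope_local_base_freq=10_000.0,
)
# With rope_local_base_freq left as None (the default), rope_base is used
# for both local and global attention, matching the previous behavior.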

litgpt/model.py

Lines changed: 13 additions & 0 deletions

@@ -203,6 +203,7 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso
             condense_ratio=self.config.rope_condense_ratio,
             base=self.config.rope_base,
             extra_config=extra_config,
+            rope_local_base_freq=self.config.rope_local_base_freq,
         )

     def set_kv_cache(

@@ -567,6 +568,7 @@ def build_rope_cache(
     base: int = 10000,
     condense_ratio: int = 1,
     extra_config: Optional[dict] = None,
+    rope_local_base_freq: Optional[float] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Enhanced Transformer with Rotary Position Embedding.

@@ -620,6 +622,17 @@ def build_rope_cache(
     if idx_theta.shape[-1] > n_elem > 1:
         idx_theta = idx_theta[..., :n_elem]

+    # If rope_local_base_freq is given, build a second RoPE table with that base for local attention.
+    # The local table uses the plain (default) RoPE formula; extra_config adjustments are not applied to it.
+    if rope_local_base_freq is not None:
+        local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
+        local_idx_theta = torch.outer(seq_idx, local_theta)
+        local_idx_theta = local_idx_theta.repeat(1, 2)
+        if local_idx_theta.shape[-1] > n_elem > 1:
+            local_idx_theta = local_idx_theta[..., :n_elem]
+
+        idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)
+
     return torch.cos(idx_theta), torch.sin(idx_theta)
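
With a local base frequency supplied, the returned cos/sin caches gain a trailing dimension of size 2, holding the global table at index 0 and the local table at index 1. The following self-contained sketch mirrors the math of the new branch (simplified: it omits condense_ratio, extra_config, and the odd-n_elem truncation) just to make the resulting shapes concrete; rope_cos_sin is a made-up helper name, not part of litgpt.

from typing import Optional, Tuple

import torch


def rope_cos_sin(
    seq_len: int,
    n_elem: int,
    base: float = 10000.0,
    rope_local_base_freq: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Simplified mirror of build_rope_cache: one theta table per base frequency.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    seq_idx = torch.arange(seq_len).float()
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)  # (seq_len, n_elem)

    if rope_local_base_freq is not None:
        # Second table for local (sliding-window) attention, stacked along a new last dim.
        local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2).float() / n_elem))
        local_idx_theta = torch.outer(seq_idx, local_theta).repeat(1, 2)
        idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)  # (seq_len, n_elem, 2)

    return torch.cos(idx_theta), torch.sin(idx_theta)


cos, sin = rope_cos_sin(seq_len=8, n_elem=16)
print(cos.shape)  # torch.Size([8, 16])

cos, sin = rope_cos_sin(seq_len=8, n_elem=16, rope_local_base_freq=10_000.0)
print(cos.shape)  # torch.Size([8, 16, 2]) -- index 0: global table, index 1: local table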