74 changes: 62 additions & 12 deletions litgpt/model.py
@@ -453,6 +453,12 @@ def forward(
if not isinstance(self.kv_cache, KVCache):
raise TypeError("You need to call `gpt.set_kv_cache()`")
k, v = self.kv_cache(input_pos, k, v)

if self.apply_sliding_window_attention:
actual_kv_len = k.size(2)
if mask is not None and mask.size(-1) != actual_kv_len:
mask = mask[..., :actual_kv_len]

if input_pos_maxp1 is not None:
# Subselect along sequence dimension
k = k[..., :input_pos_maxp1, :]
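A minimal standalone sketch (hypothetical shapes, not litgpt's API) of what the trimming above does: when the sliding-window cache returns fewer key/value rows than the mask was built for, the mask's last dimension is sliced down to match the cached length.

```python
import torch

mask = torch.zeros(1, 1, 1, 8)            # additive mask built for a max context of 8 keys
k = torch.zeros(1, 1, 4, 16)              # cache capped at a 4-row sliding window
actual_kv_len = k.size(2)
if mask.size(-1) != actual_kv_len:
    mask = mask[..., :actual_kv_len]      # drop mask columns that have no cached key
print(mask.shape)                         # torch.Size([1, 1, 1, 4])
```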
@@ -479,13 +485,22 @@ def forward(
│ True True True True │ │ False False True True │ │ False False True True │
└────────────────────────┘ └───────────────────────┘ └─────────────────────────┘
"""
if mask is None:
mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), float("-inf"))
mask = mask.view(1, 1, *mask.shape)
sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size)
sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf"))
mask += sliding_window_bias
if input_pos is not None:
# FIX: During inference with KV cache, sliding window is handled by cache size
# Only need causal mask, which is already applied by default
pass
else:
if mask is None:
mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), float("-inf"))
mask = mask.view(1, 1, *mask.shape)
sliding_window_bias = torch.zeros_like(mask)
for i in range(T):
window_start = max(0, i - self.config.sliding_window_size)
if window_start > 0:
sliding_window_bias[:, :, i, :window_start] = float("-inf")

mask = mask + sliding_window_bias

# Efficient attention using Flash Attention CUDA kernels.
# NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
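For reference, a small standalone sketch (hypothetical `T` and `W`, not litgpt code) that produces the same additive bias as the loop above using `triu`/`tril`: keys after the query and keys more than `W` tokens behind it get `-inf`, so each query sees itself plus the `W` preceding tokens. Note that this keeps `W + 1` keys visible per query, matching the loop's `window_start = max(0, i - sliding_window_size)`, whereas the earlier `tril(diagonal=-self.config.sliding_window_size)` bias masked one position more and kept only `W`.

```python
import torch

T, W = 6, 3                                               # sequence length, window size
causal = torch.ones(T, T).triu(diagonal=1)                # 1 where the key comes after the query
causal = causal.masked_fill(causal.bool(), float("-inf"))
window = torch.ones(T, T).tril(diagonal=-(W + 1))         # 1 where the key is more than W tokens behind
window = window.masked_fill(window.bool(), float("-inf"))
bias = (causal + window).view(1, 1, T, T)                 # 0 inside the window, -inf elsewhere
print(bias[0, 0])
```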
@@ -527,7 +542,14 @@ def build_kv_cache(
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> "KVCache":
v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size)
# FIX: For sliding window layers, limit cache to window size
if self.apply_sliding_window_attention and self.config.sliding_window_size is not None:
effective_cache_size = min(max_seq_length, self.config.sliding_window_size)
else:
effective_cache_size = max_seq_length

v_shape = (batch_size, self.config.n_query_groups, effective_cache_size, self.config.head_size)

if rope_cache_length is None:
if self.config.rotary_percentage != 1.0:
raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
@@ -536,10 +558,19 @@
k_shape = (
batch_size,
self.config.n_query_groups,
max_seq_length,
effective_cache_size, # Changed from max_seq_length
rope_cache_length + self.config.head_size - self.config.rope_n_elem,
)
return KVCache(k_shape, v_shape, device=device, dtype=dtype)

# Pass sliding window info to KVCache
return KVCache(
k_shape,
v_shape,
device=device,
dtype=dtype,
is_sliding_window=self.apply_sliding_window_attention,
sliding_window_size=self.config.sliding_window_size if self.apply_sliding_window_attention else None,
)

def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with legacy checkpoints."""
@@ -951,10 +982,15 @@ def __init__(
v_shape: Tuple[int, int, int, int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
is_sliding_window: bool = False,
sliding_window_size: Optional[int] = None,
) -> None:
super().__init__()
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
self.is_sliding_window = is_sliding_window
self.sliding_window_size = sliding_window_size
self.max_cache_len = k_shape[2]

def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -979,8 +1015,22 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
self.v = self.v.to(v.dtype)
# update the cache
bs = k.size(0)
k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)
if self.is_sliding_window:
# FIX: Implement circular buffer for sliding window
# Map absolute positions to circular buffer indices
cache_positions = input_pos % self.max_cache_len
k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)
v = batched_index_copy_(self.v[:bs, ...], -2, cache_positions, v)

max_pos = input_pos.max().item()
if max_pos < self.max_cache_len:
k = k[:, :, : max_pos + 1, :]
v = v[:, :, : max_pos + 1, :]
else:
# Standard KV cache (global attention)
k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)

return k, v

def reset_parameters(self) -> None:
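For illustration, a standalone sketch (hypothetical sizes, single batch/head, not the `KVCache` class) of the circular-buffer indexing used in `forward` above: absolute token positions are mapped into a fixed-size buffer with a modulo, so once the window is full, new tokens overwrite the oldest slots.

```python
import torch

W = 4                                     # cache length == sliding window size
cache = torch.zeros(W)                    # one scalar per slot, standing in for a K/V row
for pos in range(7):                      # decode 7 tokens, one at a time
    slot = pos % W                        # circular-buffer index (cache_positions above)
    cache[slot] = float(pos)              # overwrites the oldest entry once pos >= W
print(cache)                              # tensor([4., 5., 6., 3.]) -> positions 4, 5, 6, 3
```

As the sketch shows, rows are no longer stored in absolute-position order once the buffer wraps, so any mask or positional indexing applied downstream has to account for the permuted layout.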