diff --git a/litgpt/model.py b/litgpt/model.py
index 01ea83ad4a..936c87b895 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -453,6 +453,12 @@ def forward(
             if not isinstance(self.kv_cache, KVCache):
                 raise TypeError("You need to call `gpt.set_kv_cache()`")
             k, v = self.kv_cache(input_pos, k, v)
+
+            if self.apply_sliding_window_attention:
+                actual_kv_len = k.size(2)
+                if mask is not None and mask.size(-1) != actual_kv_len:
+                    mask = mask[..., :actual_kv_len]
+
             if input_pos_maxp1 is not None:
                 # Subselect along sequence dimension
                 k = k[..., :input_pos_maxp1, :]
@@ -479,13 +485,22 @@ def forward(
             │ True True  True  True  │  │ False False True True │  │ False False True  True  │
             └────────────────────────┘  └───────────────────────┘  └─────────────────────────┘
             """
-            if mask is None:
-                mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)
-                mask.masked_fill_(mask.bool(), float("-inf"))
-                mask = mask.view(1, 1, *mask.shape)
-            sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size)
-            sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf"))
-            mask += sliding_window_bias
+            if input_pos is not None:
+                # FIX: During inference with KV cache, sliding window is handled by cache size
+                # Only need causal mask, which is already applied by default
+                pass
+            else:
+                if mask is None:
+                    mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)
+                    mask.masked_fill_(mask.bool(), float("-inf"))
+                    mask = mask.view(1, 1, *mask.shape)
+                sliding_window_bias = torch.zeros_like(mask)
+                for i in range(T):
+                    window_start = max(0, i - self.config.sliding_window_size)
+                    if window_start > 0:
+                        sliding_window_bias[:, :, i, :window_start] = float("-inf")
+
+                mask = mask + sliding_window_bias

         # Efficient attention using Flash Attention CUDA kernels.
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
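
For reference, here is a minimal standalone sketch of the bias that the new prefill-path loop builds, using hypothetical values T = 5 and sliding_window_size = 2 (neither value comes from the patch):

    # Illustrative sketch only; mirrors the loop added in the hunk above.
    import torch

    T, window = 5, 2  # hypothetical sizes for illustration

    # Causal part: -inf strictly above the diagonal, zeros elsewhere.
    mask = torch.ones(T, T).triu(diagonal=1)
    mask.masked_fill_(mask.bool(), float("-inf"))
    mask = mask.view(1, 1, T, T)

    # Sliding-window part: for query i, mask keys that come before i - window.
    sliding_window_bias = torch.zeros_like(mask)
    for i in range(T):
        window_start = max(0, i - window)
        if window_start > 0:
            sliding_window_bias[:, :, i, :window_start] = float("-inf")

    print((mask + sliding_window_bias)[0, 0])
    # Row i is finite only for keys j with i - window <= j <= i.
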
@@ -527,7 +542,14 @@ def build_kv_cache(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> "KVCache":
-        v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size)
+        # FIX: For sliding window layers, limit cache to window size
+        if self.apply_sliding_window_attention and self.config.sliding_window_size is not None:
+            effective_cache_size = min(max_seq_length, self.config.sliding_window_size)
+        else:
+            effective_cache_size = max_seq_length
+
+        v_shape = (batch_size, self.config.n_query_groups, effective_cache_size, self.config.head_size)
+
         if rope_cache_length is None:
             if self.config.rotary_percentage != 1.0:
                 raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
@@ -536,10 +558,19 @@ def build_kv_cache(
             k_shape = (
                 batch_size,
                 self.config.n_query_groups,
-                max_seq_length,
+                effective_cache_size,  # Changed from max_seq_length
                 rope_cache_length + self.config.head_size - self.config.rope_n_elem,
             )
-        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+        # Pass sliding window info to KVCache
+        return KVCache(
+            k_shape,
+            v_shape,
+            device=device,
+            dtype=dtype,
+            is_sliding_window=self.apply_sliding_window_attention,
+            sliding_window_size=self.config.sliding_window_size if self.apply_sliding_window_attention else None,
+        )

     def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None:
         """For compatibility with legacy checkpoints."""
@@ -951,10 +982,15 @@ def __init__(
         v_shape: Tuple[int, int, int, int],
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        is_sliding_window: bool = False,
+        sliding_window_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
         self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
+        self.is_sliding_window = is_sliding_window
+        self.sliding_window_size = sliding_window_size
+        self.max_cache_len = k_shape[2]

     def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -979,8 +1015,22 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
             self.v = self.v.to(v.dtype)
         # update the cache
         bs = k.size(0)
-        k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
-        v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)
+        if self.is_sliding_window:
+            # FIX: Implement circular buffer for sliding window
+            # Map absolute positions to circular buffer indices
+            cache_positions = input_pos % self.max_cache_len
+            k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)
+            v = batched_index_copy_(self.v[:bs, ...], -2, cache_positions, v)
+
+            max_pos = input_pos.max().item()
+            if max_pos < self.max_cache_len:
+                k = k[:, :, : max_pos + 1, :]
+                v = v[:, :, : max_pos + 1, :]
+        else:
+            # Standard KV cache (global attention)
+            k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)
+            v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)
+
         return k, v

     def reset_parameters(self) -> None:
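
For reference, a small standalone sketch of how the circular-buffer mapping in the modified KVCache.forward assigns cache slots during single-token decoding, assuming a hypothetical max_cache_len (i.e. sliding window size) of 4:

    # Illustrative sketch only; mirrors cache_positions = input_pos % self.max_cache_len above.
    import torch

    max_cache_len = 4  # hypothetical window size for illustration
    for step in range(7):
        input_pos = torch.tensor([step])             # one new token per decode step
        cache_positions = input_pos % max_cache_len
        print(f"absolute position {step} -> cache slot {cache_positions.item()}")

    # Absolute position 4 overwrites slot 0, position 5 overwrites slot 1, and so on,
    # so the cache only ever holds the most recent max_cache_len tokens.
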