From 5b994c3183f35458731421d58838fd38654c845d Mon Sep 17 00:00:00 2001 From: Matthias Seeger Date: Thu, 2 Jan 2025 15:44:51 +0100 Subject: [PATCH 1/2] Support for advanced KV caching and batch generation --- litgpt/adapter.py | 109 ++-- litgpt/adapter_v2.py | 33 +- litgpt/api.py | 28 +- litgpt/attention.py | 385 ++++++++++++++ litgpt/chat/base.py | 68 ++- litgpt/config.py | 62 ++- litgpt/finetune/adapter.py | 6 +- litgpt/finetune/adapter_v2.py | 6 +- litgpt/finetune/full.py | 6 +- litgpt/finetune/lora.py | 6 +- litgpt/generate/adapter.py | 15 +- litgpt/generate/adapter_v2.py | 15 +- litgpt/generate/base.py | 221 +++++--- litgpt/generate/full.py | 17 +- litgpt/generate/sequentially.py | 31 +- litgpt/generate/tp.py | 20 +- litgpt/kvcache/__init__.py | 14 + litgpt/kvcache/base.py | 456 ++++++++++++++++ litgpt/kvcache/baselines.py | 356 +++++++++++++ litgpt/kvcache/test_utils.py | 77 +++ litgpt/kvcache/utils.py | 17 + litgpt/lora.py | 27 +- litgpt/model.py | 795 ++++++++++++++++------------ litgpt/utils.py | 19 + tests/generate/test_adapter.py | 1 + tests/generate/test_main.py | 154 +++++- tests/generate/test_sequentially.py | 1 + tests/kvcache/test_base.py | 67 +++ tests/kvcache/test_generic.py | 152 ++++++ tests/test_adapter.py | 3 +- tests/test_adapter_v2.py | 3 +- tests/test_batch.py | 85 +-- tests/test_chat.py | 3 + tests/test_model.py | 275 +++++----- 34 files changed, 2812 insertions(+), 721 deletions(-) create mode 100644 litgpt/attention.py create mode 100644 litgpt/kvcache/__init__.py create mode 100644 litgpt/kvcache/base.py create mode 100644 litgpt/kvcache/baselines.py create mode 100644 litgpt/kvcache/test_utils.py create mode 100644 litgpt/kvcache/utils.py create mode 100644 tests/kvcache/test_base.py create mode 100644 tests/kvcache/test_generic.py diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 5297df4eb3..6b505ca988 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -15,7 +15,9 @@ import torch.nn as nn from typing_extensions import Self +from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention from litgpt.config import Config as BaseConfig +from litgpt.kvcache.base import KVCache from litgpt.model import GPT as BaseModel from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention @@ -34,7 +36,11 @@ def __init__(self, config: Config) -> None: assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -42,8 +48,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_cache`? 
+ self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -57,18 +66,33 @@ def _init_weights(self, module: nn.Module) -> None: class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.model.CausalSelfAttention` that adds the attention over the adaption prompt.""" - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - if block_idx >= config.adapter_start_layer: + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__( + config=config, + block_idx=block_idx, + kv_cache=kv_cache, + ) + self._extend_forward = block_idx >= config.adapter_start_layer + if self._extend_forward: # adapter embedding layer self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) # gate for adaption @@ -76,37 +100,46 @@ def __init__(self, config: Config, block_idx: int) -> None: # kv cache for inference self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - def scaled_dot_product_attention( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + def _transform_output( + self, + y: torch.Tensor, + query: torch.Tensor, + mha: MultiHeadSelfAttention, ) -> torch.Tensor: - y = super().scaled_dot_product_attention(q, k, v, mask) - if self.block_idx < self.config.adapter_start_layer: - return y - - aT = self.config.adapter_prompt_length - if self.adapter_kv_cache is not None: - # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av - # are the same every call - ak, av = self.adapter_kv_cache - else: - prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) - aqkv = self.qkv(prefix) - q_per_kv = self.config.n_head // self.config.n_query_groups - aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) - aqkv = aqkv.permute(0, 2, 3, 1, 4) - _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) - if self.config.n_query_groups != 1: - # for MHA this is a no-op - ak = ak.repeat_interleave(q_per_kv, dim=2) - av = av.repeat_interleave(q_per_kv, dim=2) - ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) - av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) - self.adapter_kv_cache = (ak, av) - - T = q.size(2) - amask = torch.ones(T, aT, dtype=torch.bool, device=q.device) - ay = super().scaled_dot_product_attention(q, ak, av, amask) - return y + self.gating_factor * ay + if self._extend_forward: + B, T, _ = y.shape + y = y.view(B, T, self.config.n_head, self.config.head_size) + aT = self.config.adapter_prompt_length + if self.adapter_kv_cache is not None: + # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av + # are the same every call + ak, av = self.adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) + aqkv = self.qkv(prefix) + q_per_kv = self.config.n_head // self.config.n_query_groups + aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 
2, self.config.head_size) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) + av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) + self.adapter_kv_cache = (ak, av) + + amask = torch.ones(T, aT, dtype=torch.bool, device=query.device) + a_k_and_v = DefaultKeysAndValues(keys=ak, values=av) + ay, _ = mha.scaled_dot_product_attention( + query=query, + k_and_v=a_k_and_v, + mask=amask, + is_causal=False, + ) + y = (y + self.gating_factor * ay).view(B, T, -1) + + return y def reset_parameters(self) -> None: if hasattr(self, "gating_factor"): diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 363d6eb2b1..ff26b3c402 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -19,6 +19,8 @@ from litgpt.adapter import GPT as BaseModel from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig +from litgpt.attention import MultiHeadSelfAttention +from litgpt.kvcache.base import KVCache from litgpt.model import Block as BaseBlock from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -69,7 +71,11 @@ def __init__(self, config: Config) -> None: assert config.padded_vocab_size is not None self.config = config - self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = AdapterV2Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -77,8 +83,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_cache`? 
+ self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -98,9 +107,14 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.mlp = config.mlp_class(config) @@ -108,8 +122,13 @@ class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" # Copy&paste from :class:`model.CausalSelfAttention` - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) diff --git a/litgpt/api.py b/litgpt/api.py index 32cc196603..76c35cc92c 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -383,7 +383,11 @@ def distribute( kv_cache_size = model.max_seq_length else: kv_cache_size = fixed_kv_cache_size - model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device) + model.set_kv_cache( + batch_size=1, + max_seq_length=kv_cache_size, + device=fabric.device, + ) self.kv_cache_initialized = True self.fixed_kv_cache_size = fixed_kv_cache_size @@ -508,20 +512,26 @@ def generate( prompt_length = input_ids.size(0) max_returned_tokens = prompt_length + max_new_tokens + if self.fabric is not None: + device = self.fabric.device + else: + device = self.preprocessor.device if not self.kv_cache_initialized: - if self.fabric is not None: - device = self.fabric.device - else: - device = self.preprocessor.device - self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device) + self.model.set_kv_cache( + batch_size=1, + max_seq_length=max_returned_tokens, + device=device, + ) self.kv_cache_initialized = True # Dynamically grow the kv cache size if necessary if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens: - tmp_device = self.model.mask_cache.device self.model.clear_kv_cache() - self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device) - + self.model.set_kv_cache( + batch_size=1, + max_seq_length=max_returned_tokens, + device=device, + ) else: for block in self.model.transformer.h: block.attn.kv_cache.reset_parameters() diff --git a/litgpt/attention.py b/litgpt/attention.py new file mode 100644 index 0000000000..6315cedca2 --- /dev/null +++ b/litgpt/attention.py @@ -0,0 +1,385 @@ +import math +from typing import Optional, Tuple + +import torch +from torch.nn import functional as F + +from litgpt.config import Config + + +class KeysAndValues: + """ + Object passed to :meth:`MultiHeadSelfAttention.__call__`. Allows to access + keys or values, but (in general) not both at the same time. Implementations + may use the same buffer to return them in the methods below. 
+ + However, if :meth:`both_in_parallel` returns `True`, the tensors returned + by :meth:`keys` and :meth:`values` may be used in parallel, since they are + supported by separate buffers. + + """ + + def keys(self) -> torch.Tensor: + """ + Returns: + keys tensor, shape `(eff_batch_size, n_query_groups, T, head_size)`, + where `T <= cache_length` is the current cache length) + + """ + raise NotImplementedError() + + def values(self) -> torch.Tensor: + """ + Returns: + values tensor, shape `(eff_batch_size, n_query_groups, T, head_size)`, + where `T <= cache_length` is the current cache length) + + """ + raise NotImplementedError() + + def both_in_parallel(self) -> bool: + """ + Returns: + Can use both `keys` and `values` in parallel? Otherwise, can only + use one of them at the same time + """ + return False + + +class DefaultKeysAndValues(KeysAndValues): + def __init__(self, keys: torch.Tensor, values: torch.Tensor): + # The final dimension of K and V can be different (in general) + assert keys.shape[:-1] == values.shape[:-1] and keys.ndim == 4, (keys.shape, values.shape) + self._keys = keys + self._values = values + + def keys(self) -> torch.Tensor: + return self._keys + + def values(self) -> torch.Tensor: + return self._values + + def both_in_parallel(self) -> bool: + """ + Keys and values are supported by different buffers, so they can be + used at the same time. + + """ + return True + + +class MultiHeadSelfAttention: + """ + Maintains code for the inner part of multi-head self-attention which is not + parameterized. This is used both by :class:`CausalSelfAttention` and by the + default KV cache implementation :class:`DefaultKVCache`. + + """ + + def __init__(self, config: Config) -> None: + self.config = config + + def set_seq_length( + self, + value: int, + device: torch.device, + ) -> None: + pass # Currently, we don't use this + + def __call__( + self, + query: torch.Tensor, + k_and_v: KeysAndValues, + block_idx: int, + input_pos: Optional[int] = None, + return_attn_weights: bool = False, + token_positions: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + + Args: + query: Queries, shape `(batch_size, n_heads, q_len, head_size)` + k_and_v: Access to keys and values, shape + (batch_size, n_query_groups, kv_len, head_size)` + block_idx: Index of block (or layer) in model + input_pos: Position in input sequence. Defaults to 0 + return_attn_weights: If this is `True` and `input_pos > 0`, the + attention weights (or scores) are returned as second argument + token_positions: Required if `input_pos > 0`. Contains token + positions in KV cache. This is needed to select the correct + part of the mask matrix + + Returns: + `attn_output, attn_weights`, where `attn_weights` is `None` if + attention weights are not returned. + + """ + # We need the attention mask if there is sliding window attention, + # or if `input_pos > 0` and T > 1. 
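+        # Summary of the cases handled below:
+        #   input_pos is None -> full-sequence forward (training); causal attention
+        #   input_pos == 0    -> prefill of the KV cache; also causal
+        #   input_pos > 0     -> incremental decode against the cache; a mask
+        #                        slice over `token_positions` is needed if T > 1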
+ for_prefill = input_pos == 0 + is_causal = input_pos is None or for_prefill + if not is_causal and token_positions is None: + raise ValueError("token_positions must be given if input_pos > 0") + apply_sliding_window_attention = ( + self.config.sliding_window_size is not None and self.config.sliding_window_indices[block_idx] == 1 + ) + B, _, T, _ = query.shape + use_mask = apply_sliding_window_attention or (not is_causal and T > 1) + mask = None + if use_mask: + # Special case requires building a mask + if is_causal: + mask = build_mask_cache( + max_seq_length=T, + sliding_window_size=self.config.sliding_window_size, + dtype=query.dtype, + device=query.device, + ).view(1, 1, T, T) + is_causal = False + else: + # We need a mask if T > 1, since inference needs to be causal + # for the new tokens + assert input_pos > 0 + mask = build_mask_slice( + input_pos=input_pos, + num=T, + token_positions=token_positions, + dtype=query.dtype, + device=query.device, + sliding_window_size=self.config.sliding_window_size, + ) + # mask has shape (B, n_query_groups, T, kv_len), must have + # shape (B, n_head, T, kv_len) + nh_q = self.config.n_head + nh_k = self.config.n_query_groups + q_per_kv = nh_q // nh_k + if q_per_kv > 1: + mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1) + + # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. + # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) + return_scores = not (input_pos is None or for_prefill) and return_attn_weights + y, scores = self.scaled_dot_product_attention( + query, + k_and_v, + mask, + is_causal, + return_scores, + ) + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, -1) + return y, scores + + def scaled_dot_product_attention( + self, + query: torch.Tensor, + k_and_v: KeysAndValues, + mask: Optional[torch.Tensor] = None, + is_causal: bool = True, + return_scores: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert mask is None or not is_causal, "Cannot have mask and is_causal=True" + scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) + + # We cannot call PyTorch scaled_dot_product_attention if: + # - Attention scores need to be returned; or + # - Logit softcapping is required; or + # - We cannot access keys and values from `k_and_v` in parallel (this + # never happens if `is_causal == True`) + if return_scores or self.config.attention_logit_softcapping is not None or not k_and_v.both_in_parallel(): + y, scores = scaled_dot_product_attention( + query=query, + k_and_v=k_and_v, + scale=scale, + mask=mask, + attention_logit_softcapping=self.config.attention_logit_softcapping, + is_causal=is_causal, + ) + if not return_scores: + scores = None + else: + # We need `key` and `value` at the same time here. For the training + # use case, this will be the case, since `k_and_v` is the default + # in this case. 
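+            # Fast path: defer to `torch.nn.functional.scaled_dot_product_attention`,
+            # which can dispatch to Flash / memory-efficient kernels. `enable_gqa`
+            # lets SDPA broadcast the `n_query_groups` K/V heads over the `n_head`
+            # query heads without materializing repeated copies (PyTorch 2.5+).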
+ key = k_and_v.keys() + value = k_and_v.values() + y = F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=mask, + dropout_p=0.0, + scale=scale, + is_causal=is_causal, + enable_gqa=self.config.n_query_groups < self.config.n_head, + ) + scores = None + return y.transpose(1, 2), scores + + +def scaled_dot_product_attention( + query: torch.Tensor, + k_and_v: KeysAndValues, + scale: float, + mask: Optional[torch.Tensor] = None, + attention_logit_softcapping: Optional[float] = None, + is_causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = query.dtype + key = k_and_v.keys() + # Scale both `query` and `key` by `sqrt(scale)`, as is done in + # `torch.nn.functional.scaled_dot_product_attention` + sqrt_scale = math.sqrt(scale) + scores = _attention_compute_scores( + sqrt_scale * query, + sqrt_scale * key, + ) + scores = do_softcapping(scores, attention_logit_softcapping) + if mask is None and is_causal: + T = query.shape[2] + assert key.size(2) == T, "is_causal=True only if query, key have same size" + mask = torch.ones(T, T, dtype=dtype, device=query.device).triu(diagonal=1) + mask.masked_fill_(mask.bool(), torch.finfo(dtype).min) + mask = mask.view(1, 1, T, T) + if mask is not None: + scores = scores + mask + scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=dtype) + value = k_and_v.values() + return _attention_compute_weighted_values(scores, value), scores + + +def _attention_compute_scores( + query: torch.Tensor, + key: torch.Tensor, +) -> torch.Tensor: + assert query.ndim == key.ndim == 4 + assert query.shape[0] == key.shape[0] and query.shape[3] == key.shape[3] + nh_q = query.shape[1] + nh_k = key.shape[1] + assert nh_q % nh_k == 0 + # - query: (bs, nh_q, T_q, hs) + # - key: (bs, nh_k, T_k, hs) + q_per_kv = nh_q // nh_k + key_transposed = key.mT # (bs, nh_k, hs, T_k) + if q_per_kv == 1: + return query @ key_transposed + else: + assert q_per_kv > 1 + q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:] + _query = query.view(*q_shape) + key_transposed = key_transposed.unsqueeze(2) + # At this point: + # - _query: (bs, nh_k, q_per_kv, T_q, hs) + # - key_transposed: (bs, nh_k, 1, hs, T_k) + # - scores: (bs, nh_k, q_per_kv, T_q, T_k) + scores = torch.matmul(_query, key_transposed) + s_shape = query.shape[:-1] + (key.shape[2],) + return scores.view(*s_shape) + + +def _attention_compute_weighted_values( + scores: torch.Tensor, + value: torch.Tensor, +) -> torch.Tensor: + assert scores.ndim == value.ndim == 4 + assert scores.shape[0] == scores.shape[0] and scores.shape[3] == value.shape[2] + nh_q = scores.shape[1] + nh_k = value.shape[1] + assert nh_q % nh_k == 0 + # - scores: (bs, nh_q, T_q, T_k) + # - value: (bs, nh_k, T_k, hs) + q_per_kv = nh_q // nh_k + if q_per_kv == 1: + return scores @ value + else: + s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:] + _scores = scores.view(*s_shape) + _value = value.unsqueeze(2) + # At this point: + # - _scores: (bs, nh_k, q_per_kv, T_q, T_k) + # - _value: (bs, nh_k, 1, T_k, hs) + # - result: (bs, nh_k, q_per_kv, T_q, hs) + result = torch.matmul(_scores, _value) + r_shape = scores.shape[:-1] + (value.shape[-1],) + return result.view(*r_shape) + + +def build_mask_cache( + max_seq_length: int, + sliding_window_size: Optional[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Global Window Sliding window Sliding window + attention mask + bias = attention mask + ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ + │ 
True False False False │ │ True True True True │ │ True False False False │ + │ True True False False │ │ True True True True │ │ True True False False │ + │ True True True False │ │ False True True True │ │ False True True False │ + │ True True True True │ │ False False True True │ │ False False True True │ + └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ + """ + # Usual causal mask: + mask = torch.ones( + max_seq_length, max_seq_length, device=device, dtype=dtype, + ).triu(diagonal=1) + if sliding_window_size is not None: + mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size) + mask.masked_fill_(mask.bool(), torch.finfo(dtype).min) + return mask + + +def build_mask_slice( + input_pos: int, + num: int, + token_positions: torch.Tensor, + dtype: torch.dtype, + device: torch.device, + sliding_window_size: Optional[int] = None, +) -> torch.Tensor: + """ + Returns mask for case `input_pos > 0` in :class:`MultiHeadSelfAttention`. + + Args: + input_pos: Position in input sequence, must be positive + num: Length of query argument `q_len` + token_positions: Token positions in KV cache, shape + `(eff_batch_size, n_query_groups, cache_length)` + dtype: Data type of the output mask + device: Device of the output mask + sliding_window_size: Parameter for sliding window attention + + Returns: + Mask tensor, shape `(eff_batch_size, n_query_groups, num, cache_length)` + + """ + # Build boolean mask, then map False -> 0, True -> -infty + # If (i, j) indexes the complete (seq_len, seq_len) mask matrix, + # causality is given by I(i < j). If `sliding_window_size` is given, + # this translates to I(i >= j + sws) if sws = sliding_window_size. + assert token_positions.ndim == 3 + tp_dtype = token_positions.dtype + token_positions = token_positions.unsqueeze(2).to(device=device) + kwargs = dict(device=device, dtype=tp_dtype) + bool_mask = torch.arange( + input_pos, input_pos + num, **kwargs, + ).view(1, 1, -1, 1) < token_positions + if sliding_window_size is not None: + extra_mask = torch.arange( + input_pos - sliding_window_size, + input_pos + num - sliding_window_size, + **kwargs, + ).view(1, 1, -1, 1) >= token_positions + bool_mask += extra_mask + mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device) + mask.masked_fill_(bool_mask, torch.finfo(dtype).min) + return mask + + +def do_softcapping(x: torch.Tensor, thresh: Optional[float]) -> torch.Tensor: + if thresh is not None: + return torch.tanh(x / thresh) * thresh + else: + return x diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py index 123028b590..ad38f903a6 100644 --- a/litgpt/chat/base.py +++ b/litgpt/chat/base.py @@ -30,6 +30,7 @@ def generate( prompt: torch.Tensor, max_returned_tokens: int, *, + prompt_chunksize: int = 16, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0, @@ -62,20 +63,31 @@ def generate( from litgpt.generate.base import generate_fn return generate_fn( - include_prompt=False, - include_eos=False, model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, + include_prompt=False, + include_eos=False, ) def process_prompt( - prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens + prompt: str, + model: GPT, + tokenizer, + prompt_style, + fabric, + max_new_tokens: int, + prompt_chunksize: int, + temperature: float, + top_k: Optional[int], + top_p: float, + stop_tokens: Tuple[List[int], 
...], ): prompt = prompt_style.apply(prompt=prompt) encoded_prompt = tokenizer.encode(prompt, device=fabric.device) @@ -83,16 +95,16 @@ def process_prompt( if max_new_tokens is None: max_returned_tokens = model.max_seq_length else: - first_turn = model.mask_cache is None max_returned_tokens = encoded_prompt.size(0) + max_new_tokens - if first_turn or max_returned_tokens > model.max_seq_length: + msl = model.max_seq_length + if max_returned_tokens > msl or model.config.block_size == msl: model.max_seq_length = max_returned_tokens - model.set_kv_cache(batch_size=1, device=fabric.device) y: Iterator[torch.Tensor] = generate( - model, - encoded_prompt, - max_returned_tokens, + model=model, + prompt=encoded_prompt, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, @@ -111,8 +123,7 @@ def process_prompt( t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() + model.clear_kv_cache() fabric.print( f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens", file=sys.stderr, @@ -120,7 +131,19 @@ def process_prompt( fabric.print() -def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens): +def interact( + multiline: bool, + model: GPT, + tokenizer, + prompt_style, + fabric, + max_new_tokens: int, + prompt_chunksize: int, + temperature: float, + top_k: Optional[int], + top_p: float, + stop_tokens: Tuple[List[int], ...], +): while True: try: if not multiline: @@ -143,7 +166,17 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max break process_prompt( - prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens + prompt=prompt, + model=model, + tokenizer=tokenizer, + prompt_style=prompt_style, + fabric=fabric, + max_new_tokens=max_new_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=stop_tokens, ) @@ -152,6 +185,7 @@ def main( checkpoint_dir: Path, *, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -167,6 +201,11 @@ def main( checkpoint_dir: A local path to a directory containing the model weights or a valid model name. You can get a list of valid model names via the `litgpt download list` command line argument. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
In top-p sampling, the next token is sampled from the highest probability tokens @@ -261,8 +300,9 @@ def main( tokenizer=tokenizer, prompt_style=prompt_style, fabric=fabric, - temperature=temperature, max_new_tokens=(None if compile else max_new_tokens), + prompt_chunksize=prompt_chunksize, + temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens, diff --git a/litgpt/config.py b/litgpt/config.py index 4f5205f832..a5cb2ebc5f 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Literal, Optional, Type, Union +from typing import Any, Callable, List, Literal, Optional, Type, Union import torch import yaml @@ -22,6 +22,12 @@ def find_multiple(n: int, k: int) -> int: return n return n + k - (n % k) +# See `Config.start_of_layer_hook`. A start of layer hook is called just before +# a layer is computed. The call is `hook(x, block_idx, input_pos)`, where +# `x` is the layer input, `block_idx` the number of the layer, and `input_pos` +# the position in the sequence (see :meth:`GPT.forward`). +StartOfLayerHook = Callable[[torch.Tensor, int, Optional[int]], None] + @dataclass class Config: @@ -68,8 +74,13 @@ class Config: n_query_groups: Optional[int] = None attn_bias: bool = False attention_scores_scalar: Optional[int] = None + # If `sliding_window_size` is given, sliding window attention with this + # size is used in layers where `sliding_window_indices` has a 1. The + # default is all 1, so that sliding window attention is used in all + # layers. If `len(sliding_window_indices) > n_layer`, we only use the + # initial part. sliding_window_size: Optional[int] = None - sliding_window_indices: Optional[List] = None + sliding_window_indices: Optional[List[int]] = None # if `attention_logit_softcapping` is used, cannot use optimized # `torch.nn.functional.scaled_dot_product_attention` (which implements # Flash attention), may result in higher memory and runtime footprint. @@ -92,9 +103,17 @@ class Config: lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None # The base period of the RoPE embeddings for local attention. - # If not provided, rope_theta will be used for both local and global attention. + # If not provided, `rope_base` will be used for both local and global attention. rope_local_base_freq: Optional[float] = None - rope_indices: Optional[List] = None + # If provided, must have `>= n_layer` entries, either 0 or 1. For 0, + # `rope_base` is used, for 1 `rope_local_base_freq` is used. If + # `len(rope_indices) > n_layer`, we only use the initial part. + rope_indices: Optional[List[int]] = None + # This hook is called in `GPT.forward` at the start of each layer, + # passing the (detached) layer input, the layer index, and `input_pos`. + # It is also called with the final layer output (which is the input + # into the head block), passing `n_layer` as second argument. 
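+    # Example (hypothetical): with `traces = []`, a hook such as
+    #   `lambda x, block_idx, input_pos: traces.append((block_idx, input_pos, x.norm().item()))`
+    # records a per-layer activation statistic during each forward pass.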
+ start_of_layer_hook: Optional[StartOfLayerHook] = None def __post_init__(self): if not self.name: @@ -125,11 +144,19 @@ def __post_init__(self): self.rope_n_elem = int(self.rotary_percentage * self.head_size) - if self.sliding_window_size is not None and self.sliding_window_indices is None: - self.sliding_window_indices = [1] * self.n_layer + if self.sliding_window_size is not None: + self.sliding_window_indices = check_indicator_and_length( + self.sliding_window_indices, + name="sliding_window_indices", + required_length=self.n_layer, + ) - if self.rope_local_base_freq is not None and self.rope_indices is None: - self.rope_indices = [1] * self.n_layer + if self.rope_local_base_freq is not None: + self.rope_indices = check_indicator_and_length( + self.rope_indices, + name="rope_indices", + required_length=self.n_layer, + ) @classmethod def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]: @@ -196,6 +223,25 @@ def norm_class(self) -> Type: return getattr(torch.nn, self.norm_class_name) +def check_indicator_and_length( + params: Optional[List[int]], + name: str, + required_length: int, + use_initial_part: bool = True, + def_val: int = 1, +) -> List[int]: + if params is None: + return [def_val] * required_length + if len(params) != required_length: + if use_initial_part and len(params) > required_length: + params = params[:required_length] + else: + raise ValueError(f"{name} = {params}, must have length {required_length}") + if not set(params).issubset({0, 1}): + raise ValueError(f"{name} = {params}, must only contain 0 and 1") + return params + + ######################## # Stability AI StableLM ######################## diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index 42479baba5..6c1c95e680 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -426,7 +426,11 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) model.clear_kv_cache() model.train() diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index c1ba67521c..f6b3ddc96f 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -449,7 +449,11 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) model.clear_kv_cache() model.train() diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 22699b8c5c..6a670f1c42 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -397,7 +397,11 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + 
max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) model.clear_kv_cache() model.train() diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 9593e1d4fe..928b34cb8c 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -455,7 +455,11 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E # do not set `max_seq_length=max_returned_token` because memory is not a concern here model.set_kv_cache(batch_size=1) output = generate( - model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, ) model.clear_kv_cache() model.train() diff --git a/litgpt/generate/adapter.py b/litgpt/generate/adapter.py index fb7f75c5ba..7cc0d949ce 100644 --- a/litgpt/generate/adapter.py +++ b/litgpt/generate/adapter.py @@ -33,6 +33,7 @@ def main( adapter_path: Path = Path("out/finetune/adapter/final/lit_model.pth.adapter"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -55,6 +56,11 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -134,7 +140,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/adapter_v2.py b/litgpt/generate/adapter_v2.py index e7a65fa528..1bfcf5ab38 100644 --- a/litgpt/generate/adapter_v2.py +++ b/litgpt/generate/adapter_v2.py @@ -33,6 +33,7 @@ def main( adapter_path: Path = Path("out/finetune/adapter-v2/final/lit_model.pth.adapter_v2"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -55,6 +56,11 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. 
+ Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -134,7 +140,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 565ef08e23..73e0aad73f 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -75,34 +75,50 @@ def sample( def next_token( model: GPT, - input_pos: torch.Tensor, x: torch.Tensor, - input_pos_maxp1: Optional[int] = None, + input_pos: Optional[int], **sample_kwargs: Dict[str, Any], ) -> torch.Tensor: - logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1) + logits = model(x, input_pos=input_pos) _next = sample(logits, **sample_kwargs).to(dtype=torch.int64) return _next -def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor: - assert len(logits) == len(kwargs), "logits and kwargs must have the same length." +def batched_sample( + logits_stack: torch.Tensor, + kwargs: Union[dict, list[dict]], +) -> torch.Tensor: + # Unbind the logits stack into a list of logits. + logits = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0) + logits = [l.unsqueeze(0) for l in logits] + _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * len(logits) + assert len(logits) == len(_kwargs), "logits and kwargs must have the same length." return torch.stack( - [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0 + [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(_kwargs, logits)], dim=0 ) def batched_next_token( - model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]] + model: GPT, + x: torch.Tensor, + input_pos: Optional[int], + kwargs: Union[dict, list[dict]], ) -> torch.Tensor: - # Where: - # input_pos is a 1d tensor of shape [seq_length...] - # x is context tokens to add to the kvcache. - # For prefill, x is a 2d tensor of shape [batch_size, prompt_length]. - # For subsequent tokens, x is a 2d tensor of shape [batch_size, 1]. - # kwargs is a list of dictionaries, each containing the keyword arguments for the sample function. - # If one dictionary is passed, it's repeated for each sample in the batch. + """ + Args: + model: GPT model. If `input_pos` is not `None`, its KV caches must be + assigned + x: Context tokens to be used as input, shape `(batch_size, num)`. When + used to sample new tokens, we have `num == 1`. + input_pos: Position of `x` in the full sequence. See + :meth:`GPT.forward` + kwargs: Sampling parameters (can be different for each batch dimension) + + Returns: + New samples corresponding to inputs `x` + + """ # In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length]. # That way, we can support prompts of different sizes. # This means making the rope cache and kvcache forward() work with batches. Currently, they do not. 
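Note: the chunked prompt processing used below (in `generate_fn` and
`batched_generate_fn`) follows a simple convention. A minimal sketch, assuming a
`model` whose dense KV caches were created via `model.set_kv_cache(batch_size=1)`
(`kv_cache_max_prefill_length` and `kv_cache_max_tokens_forward` are the new
`GPT` methods introduced by this patch):

    import torch

    prompt = torch.randint(0, model.config.padded_vocab_size, (1, 512))
    prompt_size = prompt.size(1)
    max_prefill = model.kv_cache_max_prefill_length()  # None means "no limit"
    end = prompt_size if max_prefill is None else min(prompt_size, max_prefill)
    input_pos = 0
    while input_pos < prompt_size:
        # Process the next chunk; the KV caches advance by `end - input_pos` positions
        logits = model(prompt[:, input_pos:end], input_pos=input_pos)
        input_pos = end
        chunk = min(16, model.kv_cache_max_tokens_forward(), prompt_size - input_pos)
        end += chunk
    # The last time slice of `logits` scores the first generated token.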
@@ -113,17 +129,11 @@ def batched_next_token( # After this problem is resolved, there will be another problem. That being, continuous batched prefill. # If you have any ideas on this, let me know. I don't think that padding input_pos is viable. - _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0) - # Run the model on the batch. - logits_stack = model(x, input_pos) - - # Unbind the logits stack into a list of logits. - logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0) - logits_list = [l.unsqueeze(0) for l in logits_list] + logits_stack = model(x, input_pos=input_pos) # Return the next token for each sample in the batch. - return batched_sample(logits_list, kwargs=_kwargs) + return batched_sample(logits_stack, kwargs=kwargs) @torch.inference_mode() @@ -131,6 +141,7 @@ def generate_fn( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, temperature: float = 1.0, top_k: Optional[int] = None, @@ -146,6 +157,10 @@ def generate_fn( model: The model to use. prompt: The tokenized prompt to generate from. max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. temperature: The temp to pass to sample(). top_k: The top_k to pass to sample(). top_p: The top_p to pass to sample(). @@ -155,7 +170,13 @@ def generate_fn( """ prompt_size = prompt.size(0) - device = prompt.device + if prompt_size == 0: + raise ValueError("prompt must not be empty") + sample_kwargs = dict( + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) assert max_returned_tokens > prompt_size, ( f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}." @@ -167,31 +188,44 @@ def generate_fn( if include_prompt: yield prompt + # Prompt processing. The first part of the prompt (possibly all of it) + # is processed with a prefill. If the prompt is larger than the KV + # cache length, we need to use sequential processing after that. + max_prefill_length = model.kv_cache_max_prefill_length() + if max_prefill_length is None: + end = prompt_size + else: + end = min(prompt_size, max_prefill_length) + input_pos = 0 + while input_pos < prompt_size: + inputs = prompt[input_pos:end].view(1, -1) + # We may need the last time slice of `all_logits` below: + all_logits = model(inputs, input_pos=input_pos) + input_pos = end + # Note that `max_tokens_forward` can change during the course of + # prompt processing: + chunksize = min((prompt_chunksize, model.kv_cache_max_tokens_forward(), prompt_size - input_pos)) + end += chunksize + + # Generation loop: One token per iteration + tokens = [] stop_progress = [0] * len(stop_tokens) yielded_idx = 0 - - # Generate output tokens. - # The first token generated is the prefill token. - # The input_pos for this token is the width of the entire prompt. - # For subsequent iterations, it's the index in the context for the token that we're generating. - tokens = [] - token = prompt - prefill_token = True - input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) - # input_pos_maxp1 introduces data-dependent shapes and control flow. - # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc. 
- input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None for current_idx in range(max_returned_tokens - prompt_size): # Generate the token - token = next_token( - model, - input_pos, - token.view(1, -1), - input_pos_maxp1=input_pos_maxp1, - temperature=temperature, - top_k=top_k, - top_p=top_p, - ) + if current_idx == 0: + # First token sampled from the final logits output for prompt + # processing + token = sample(all_logits, **sample_kwargs).to(dtype=torch.int64) + all_logits = None + else: + token = next_token( + model=model, + x=token.view(1, -1), + input_pos=input_pos, + **sample_kwargs, + ) + input_pos += 1 tokens.append(token) int_token = token.item() @@ -221,15 +255,6 @@ def generate_fn( yield from y_tokens yielded_idx = safe_idx - # Update input_pos for the next iteration. - if prefill_token: - prefill_token = False - input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) - else: - input_pos.add_(1) - if input_pos_maxp1 is not None: - input_pos_maxp1 += 1 - # Yield any remaining tokens if yielded_idx < len(tokens): yield from tokens[yielded_idx:] @@ -242,6 +267,7 @@ def batched_generate_fn( model: GPT, prompts: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, sample_args: Union[list[dict], dict], stop_tokens: Tuple[List[int], ...] = (), @@ -253,10 +279,18 @@ def batched_generate_fn( Args: model: The model to use. - prompts: A 2D tensor of shape [batch_size, prompt_length]. - max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. - sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. - stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. + prompts: A 2D tensor of shape [batch_size, prompt_length]. Note that + all prompts need to have the same length (TODO: Relax this) + max_returned_tokens: The maximum number of tokens to return, including + the prompt tokens. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. + sample_args: The dictionary of kwargs to pass to sample() for each + token for each index in the batch. + stop_tokens: A tuple of stop sequences. If any of the sequences are + generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. include_eos: Whether to output the stop tokens if generation stops early. @@ -268,12 +302,11 @@ def batched_generate_fn( prompts = prompts.unsqueeze(0) assert prompts.ndim == 2, "Prompts must be a 2D tensor." - batch_size = prompts.size(0) - max_prompt_size = prompts.size(1) + batch_size, max_prompt_size = prompts.shape device = prompts.device if isinstance(sample_args, dict): - sample_args = [sample_args] * len(prompts) + sample_args = [sample_args] * batch_size else: assert len(sample_args) == batch_size, "sample_args must have the length as the batch size." @@ -290,22 +323,42 @@ def batched_generate_fn( for i in range(max_prompt_size): yield [prompt[i].view(-1) for prompt in prompts] + # Prompt processing. The first part of the prompt (possibly all of it) + # is processed with a prefill. If the prompt is larger than the KV + # cache length, we need to use sequential processing after that. 
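+    # Note that all batch rows share the same `input_pos` here, which is why
+    # all prompts are currently required to have the same length (see above).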
+ max_prefill_length = model.kv_cache_max_prefill_length() + if max_prefill_length is None: + end = max_prompt_size + else: + end = min(max_prompt_size, max_prefill_length) + input_pos = 0 + while input_pos < max_prompt_size: + inputs = prompts[:, input_pos:end] + # We may need the last time slice of `all_logits` below: + all_logits = model(inputs, input_pos=input_pos) + input_pos = end + # Note that `max_tokens_forward` can change during the course of + # prompt processing: + chunksize = min((prompt_chunksize, model.kv_cache_max_tokens_forward(), max_prompt_size - input_pos)) + end += chunksize + stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)] # [batch_size, ~len(stop_tokens)] stop_idxes = [-1] * batch_size yielded_idx = 0 - # Generate output tokens. - # The first token generated is the prefill token. - # The input_pos for this token is the width of the entire prompt. - # For subsequent iterations, it's the index in the context for the token that we're generating. + # Generation loop: One token per iteration token_lists = [[] for _ in range(batch_size)] - tokens: torch.Tensor = prompts - prefill_token = True - input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64) for current_idx in range(max_returned_tokens - max_prompt_size): - # Generate the next token for each prompt in the batch. - # This is of shape [batch_size, 1]. - tokens = batched_next_token(model, input_pos, tokens, sample_args) + if current_idx == 0: + tokens = batched_sample(all_logits[:, -1:], kwargs=sample_args) + else: + tokens = batched_next_token( + model=model, + x=tokens, + input_pos=input_pos, + kwargs=sample_args, + ) + input_pos += 1 for i in range(batch_size): token_lists[i].append(tokens[i]) int_tokens = [token.item() for token in tokens] @@ -347,16 +400,6 @@ def batched_generate_fn( yield y_tokens yielded_idx = safe_idx - # Update input_pos for the next iteration. - if prefill_token: - prefill_token = False - - # TODO: Make the model support a batched input_pos of shape [batch_size, 1]. - # The kvcache has been fixed, but the rope cache is still broken. - input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64) - else: - input_pos.add_(1) - # Yield any remaining tokens max_token_lists = max(len(l) for l in token_lists) if yielded_idx < max_token_lists: @@ -375,6 +418,7 @@ def generate( model: GPT, prompt: torch.Tensor, max_returned_tokens: int, + prompt_chunksize: int = 16, *, temperature: float = 1.0, top_k: Optional[int] = None, @@ -390,6 +434,10 @@ def generate( model: The model to use. prompt: Tensor of shape (T) with indices of the prompt sequence. max_returned_tokens: The maximum number of tokens to return (given plus generated). + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. temperature: Scales the predicted logits by 1 / temperature. top_k: If specified, only sample among the tokens with the k highest probabilities. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
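A usage sketch for the batched path above (`pythia-14m` is just a small example
config; `sample_args` may differ per batch row, and all prompts currently must
have the same length):

    import torch

    from litgpt.config import Config
    from litgpt.generate.base import batched_generate_fn
    from litgpt.model import GPT

    model = GPT(Config.from_name("pythia-14m"))
    model.set_kv_cache(batch_size=2, max_seq_length=64)
    prompts = torch.randint(0, model.config.padded_vocab_size, (2, 32))
    sample_args = [dict(temperature=0.8, top_k=50), dict(temperature=0.2, top_k=1)]
    for step in batched_generate_fn(
        model,
        prompts=prompts,
        max_returned_tokens=64,
        prompt_chunksize=16,
        sample_args=sample_args,
        include_prompt=False,
        include_eos=True,
    ):
        # One tensor per batch row for this step (entries may be None once a
        # row has finished early on a stop sequence).
        print([None if t is None else t.item() for t in step])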
@@ -417,6 +465,7 @@ def generate( model=model, prompt=prompt, max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, @@ -435,6 +484,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -452,6 +502,10 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If the prompt is longer than the KV cache length, + prompts are processed in chunks of this size in the prefill phase. + The larger, the faster the prompt is processed, but a large chunk + size may lead to suboptimal cache decisions. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens @@ -540,9 +594,10 @@ def main( for i in range(num_samples): t0 = time.perf_counter() y = generate( - model, - encoded, - max_returned_tokens, + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index 78cc8dde7d..eb3e880ad6 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -32,6 +32,7 @@ def main( finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"), quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None, max_new_tokens: int = 100, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -54,7 +55,12 @@ def main( - bnb.int8: 8-bit quantization from bitsandbytes for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md max_new_tokens: The number of generation steps to take. - top_k: The number of top most probable tokens to consider in the sampling process. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. + top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. In top-p sampling, the next token is sampled from the highest probability tokens whose cumulative probability exceeds the threshold `top_p`. 
When specified, @@ -129,7 +135,14 @@ def main( L.seed_everything(1234) t0 = time.perf_counter() y = generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index 1aecccbfc2..ed3976c129 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -23,7 +23,7 @@ import litgpt.generate.base as generate_base from litgpt.config import Config -from litgpt.model import GPT, Block, build_mask_cache +from litgpt.model import GPT, Block from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision @@ -64,17 +64,12 @@ def sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int replace_device(submodule, replace=torch.device("cpu"), by=target_device) # in case the checkpoint was partial, materialize leftover metas _materialize_meta_tensors(submodule, target_device) - # and build the kv cache - submodule.attn.kv_cache = submodule.attn.build_kv_cache( - 1, max_seq_length, model.cos.size(-1), target_device - ) # rebuild odd ends with root: + # Setting `max_seq_length` forces other members to be built + if model.max_seq_length == max_seq_length: + model.max_seq_length = max_seq_length + 1 model.max_seq_length = max_seq_length - # the rope cache which is on meta device - model.cos, model.sin = model.rope_cache() - # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers - model.mask_cache = build_mask_cache(max_seq_length) # and everything that is not a block in the root _materialize_meta_tensors(model, root) replace_device(model, replace=torch.device("cpu"), by=root) @@ -141,6 +136,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -158,6 +154,11 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
In top-p sampling, the next token is sampled from the highest probability tokens @@ -227,6 +228,7 @@ def main( # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) + model.set_kv_cache(batch_size=1) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) t0 = time.perf_counter() @@ -253,26 +255,25 @@ def main( torch._inductor.config.coordinate_descent_tuning = True # cannot use cudagraphs because it doesn't support multiple device indices # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375 - generate_base.next_token = torch.compile(generate_base.next_token) L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( - model, - encoded, - max_returned_tokens, + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) + model.clear_kv_cache() print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index 16bd5ac878..8e334ff939 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -107,6 +107,7 @@ def main( sys_prompt: Optional[str] = None, num_samples: int = 1, max_new_tokens: int = 50, + prompt_chunksize: int = 16, top_k: Optional[int] = 50, top_p: float = 1.0, temperature: float = 0.8, @@ -124,6 +125,11 @@ def main( sys_prompt: The system prompt to use for generating the samples. num_samples: The number of text samples to generate. max_new_tokens: The number of generation steps to take. + prompt_chunksize: If even the shortest prompt is longer than the KV + cache, prompts are processed in chunks of this size in the + prefill phase. Once the shortest has been processed to the + end, we proceed with chunk size 1. + Defaults to 1, but larger values are recommended for long prompts. top_k: The number of top most probable tokens to consider in the sampling process. top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process. 
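The loop above captures the new cache lifecycle: caches are managed at the model level instead of per attention layer, so the per-block `kv_cache.reset_parameters()` calls disappear. As a usage summary, a sketch with `config`, `encoded`, `max_returned_tokens` and `tokenizer` assumed to be prepared as in these scripts (checkpoint loading and device placement elided):

import litgpt.generate.base as generate_base
from litgpt.model import GPT

model = GPT(config)
model.set_kv_cache(batch_size=1)       # create default dense KV caches for all layers

y = generate_base.generate(
    model=model,
    prompt=encoded,                    # 1D tensor of prompt token ids
    max_returned_tokens=max_returned_tokens,
    prompt_chunksize=16,
    temperature=0.8,
    top_k=50,
    top_p=1.0,
    eos_id=tokenizer.eos_id,
)
print(tokenizer.decode(y))

model.clear_kv_cache()                 # frees the default caches; no per-layer reset needed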
In top-p sampling, the next token is sampled from the highest probability tokens @@ -197,6 +203,7 @@ def main( # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): model = GPT(config) + model.set_kv_cache(batch_size=1) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device @@ -236,21 +243,26 @@ def main( torch._dynamo.config.automatic_dynamic_shapes = True torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.coordinate_descent_tuning = True - generate_base.next_token = torch.compile(generate_base.next_token, mode="reduce-overhead") L.seed_everything(1234) for i in range(num_samples): t0 = time.perf_counter() y = generate_base.generate( - model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id + model=model, + prompt=encoded, + max_returned_tokens=max_returned_tokens, + prompt_chunksize=prompt_chunksize, + temperature=temperature, + top_k=top_k, + top_p=top_p, + eos_id=tokenizer.eos_id, ) t = time.perf_counter() - t0 - for block in model.transformer.h: - block.attn.kv_cache.reset_parameters() fabric.print(tokenizer.decode(y)) tokens_generated = y.size(0) - prompt_length fabric.print( f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr ) + model.clear_kv_cache() if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr) diff --git a/litgpt/kvcache/__init__.py b/litgpt/kvcache/__init__.py new file mode 100644 index 0000000000..4d8c403227 --- /dev/null +++ b/litgpt/kvcache/__init__.py @@ -0,0 +1,14 @@ +from litgpt.kvcache.base import ( + DefaultKVCache, + KVCache, + KVCacheParams, +) +from litgpt.kvcache.baselines import DenseKVCache, LastRecentlyInsertedKVCache + +__all__ = [ + "DefaultKVCache", + "DenseKVCache", + "KVCache", + "KVCacheParams", + "LastRecentlyInsertedKVCache", +] diff --git a/litgpt/kvcache/base.py b/litgpt/kvcache/base.py new file mode 100644 index 0000000000..946a22fac5 --- /dev/null +++ b/litgpt/kvcache/base.py @@ -0,0 +1,456 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch + +from litgpt.attention import ( + DefaultKeysAndValues, + KeysAndValues, + MultiHeadSelfAttention, +) +from litgpt.config import Config + + +@dataclass(frozen=True) +class KVCacheParams: + batch_size: int + n_query_groups: int + cache_length: int + head_size: int + n_head: int + device: Optional[torch.device] + dtype: Optional[torch.dtype] + + @staticmethod + def from_config( + config: Config, + batch_size: int, + cache_length: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ) -> "KVCacheParams": + if head_size is None: + head_size = config.n_embd // config.n_head + return KVCacheParams( + batch_size=batch_size, + n_query_groups=config.n_query_groups, + cache_length=cache_length, + head_size=head_size, + n_head=config.n_head, + device=device, + dtype=dtype, + ) + + +class KVCache(torch.nn.Module): + """ + Base class for key-value (KV) caches. + + Buffers have shapes + `(batch_size, config.n_query_groups, cache_length, head_size)`, where + `head_size` is a parameter. Caching can be used for + batch size `1 <= eff_batch_size <= batch_size`, which is determined in + prefill calls (`input_pos=0`) of :meth:`forward`. 
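For illustration, `KVCacheParams.from_config` derives the cache geometry from a model config. A minimal sketch with a hypothetical small config (the bare `Config` construction mirrors the test utilities added later in this patch):

import torch

from litgpt.config import Config
from litgpt.kvcache import KVCacheParams

config = Config(n_embd=512, n_head=8, n_query_groups=2)
params = KVCacheParams.from_config(
    config, batch_size=4, cache_length=1024, dtype=torch.bfloat16
)
print(params.head_size, params.n_query_groups)  # 64, 2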
+ + Note: In general, query and key tensors need to be position-encoded + (e.g., RoPE). + + """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ): + """ + Note that `batch_size` is the maximum batch size the cache can be used + with. The effective batch size is determined when calling + :meth:`forward` with `input_pos=0`, and can change with any such prefill + call. If this is smaller than `batch_size`, then in general only parts + of the buffers are used. + + Args: + config: Model config + batch_size: Inference batch size (maximum) + cache_length: Number of slots in cache + block_idx: Index of model block (or layer). Multi-head attention + needs to know this. + device: Device for buffers + dtype: Data type for buffers + head_size: Size of final dimension of buffers. Defaults to head + size of model + """ + super().__init__() + if cache_length <= 0: + raise ValueError("cache_length must be positive") + self.batch_size = batch_size + self._n_query_groups = config.n_query_groups + self._cache_length = cache_length + if head_size is None: + head_size = config.head_size + self.head_size = head_size + self.n_head = config.n_head + self._device = device + self._dtype = dtype + self.block_idx = block_idx + # TODO: Remove once HuggingFace bug is fixed + # https://github.com/huggingface/transformers/issues/35233 + # https://github.com/huggingface/transformers/pull/35901 + self._work_around_hf_bug = config.rope_n_elem == 1 + + @property + def device(self) -> torch.device: + return self._device + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def cache_length(self) -> Optional[int]: + return self._cache_length + + @property + def n_query_groups(self) -> int: + return self._n_query_groups + + @property + def next_token_pos(self) -> Optional[int]: + """ + Returns: + Input position for next token to be generated, or `None` if cache + has not been initialized yet (call of :meth:`prefill`). + """ + raise NotImplementedError() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + input_pos: int, + ) -> torch.Tensor: + """ + Given query, key, value tensors, this method extends the KV cache with + `key`, `value`, then computes multi-head self attention. There are two + cases: + + * Prefill (`input_pos == 0`): Starts a generation loop by passing key + and value tensors. The KV cache is reset. The length must be + `num <= max_prefill_length`. The effective batch size must be + `eff_batch_size <= batch_size`. This batch size is then fixed for + subsequent calls of :meth:`forward`. + * Update (`input_pos > 0`): Continues a generation loop (or processing + of large prompt). The length must be `num <= max_tokens_forward`. + + If the cache makes eviction decisions based on scores which require + attention weights, scores for the next :meth:`forward` call need to + be computed here. + + If a sequence is generated token by token, updates always use `num=1`. + The case `num > 1` arises if large prompts are to be ingested with more + than `max_prefill_length` tokens. Note that if the cache makes eviction + decisions by scoring in :meth:`update`, then large `num` may lead to + worse decisions. On the other hand, ingesting prompts with larger `num` + is faster. + + Args: + query: New queries, + `(eff_batch_size, n_query_groups, num, head_size)`. 
Here, + `num <= max_tokens_forward` if `input_pos > 0`, and + `num <= max_prefill_length` if `input_pos == 0`. Must be + position encoded. + key: New keys, `(eff_batch_size, n_query_groups, num, head_size)`. + Must be position encoded. + value: New values, `(eff_batch_size, n_query_groups, num, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, num)`. + Some KV caches make use of this information. + input_pos: Token position of the new chunk in the full input + sequence. + + Returns: + Multi-head self-attention outputs before final linear map, + `(eff_batch_size, n_head, num, head_size)` + + """ + raise NotImplementedError() + + def get_keys_values(self) -> Optional[KeysAndValues]: + """ + Returns: + :class:`KeysAndValues` object, providing access to currently stored + keys and values tensors. If the cache is empty or has not been + initialized, `None` is returned. + + """ + raise NotImplementedError() + + @property + def max_tokens_forward(self) -> int: + """ + Note that this limit may change during the course of the generation + for certain caches. + + Returns: + Maximum number of token positions which can be treated in + :meth:`forward` with `input_pos > 0`. Depends on cache, but is + `<= cache_length` + + """ + raise NotImplementedError() + + @property + def max_prefill_length(self) -> Optional[int]: + """ + Returns: + Maximum sequence length for `key`, `value` tensors passed to + :meth:`forward` if `input_pos == 0`. If there is no such maximum + length, `None` is returned. + + """ + raise NotImplementedError() + + def get_params(self) -> KVCacheParams: + return KVCacheParams( + batch_size=self.batch_size, + n_query_groups=self.n_query_groups, + cache_length=self.cache_length, + head_size=self.head_size, + n_head=self.n_head, + device=self.device, + dtype=self.dtype, + ) + + def token_positions(self) -> torch.Tensor: + """ + Returns: + Token positions in slots of the cache, shape + `(eff_batch_size, n_query_groups, T)`.where `T <= cache_length` + is the current cache length. + """ + raise NotImplementedError() + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + """ + This is an estimate of the main buffers (which should all be allocated + up front), it does not cover temporary storage used in the methods + (make sure these are small compared to the main buffers). Also, real + memory usage may be larger due to alignment issues. + + Returns: + num_bits_total, bits_by_part (unit is bit) + + """ + raise NotImplementedError() + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + """ + Same semantics as :meth:`size_estimate`, but can be called without a + cache being created. Results may not be exactly the same, but should + be very close. + + Args: + params: KV cache parameters + **kwargs: Extra arguments (optional) + + Returns: + num_bits_total, bits_by_part (unit is bit) + + """ + raise NotImplementedError() + + def reset_parameters(self) -> None: + pass + + +class DefaultKVCache(KVCache): + """ + Default implementation of :class:`KVCache`, which implements :meth:`forward` + using scaled dot product attention. Most KV caches will inherit from this + class. 
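To illustrate the prefill/update protocol of `forward` at the level of a single cache, here is a sketch using `LastRecentlyInsertedKVCache` from `litgpt/kvcache/baselines.py` further down in this patch. The random tensors stand in for position-encoded queries, keys and values, and the small bare `Config` is hypothetical.

import torch

from litgpt.config import Config
from litgpt.kvcache import LastRecentlyInsertedKVCache

config = Config(n_embd=256, n_head=8, n_query_groups=4)
cache = LastRecentlyInsertedKVCache(config, batch_size=2, cache_length=64, block_idx=0)

B, hs = 2, config.head_size

def rand(n_heads: int, num: int) -> torch.Tensor:
    return torch.randn(B, n_heads, num, hs)

# Prefill (input_pos=0): resets the cache and fixes the effective batch size.
y = cache(
    query=rand(config.n_head, 12),
    key=rand(config.n_query_groups, 12),
    value=rand(config.n_query_groups, 12),
    token_idx=torch.zeros(B, 12, dtype=torch.long),
    input_pos=0,
)

# Update (input_pos > 0): appends new entries and attends over the whole cache.
y = cache(
    query=rand(config.n_head, 1),
    key=rand(config.n_query_groups, 1),
    value=rand(config.n_query_groups, 1),
    token_idx=torch.zeros(B, 1, dtype=torch.long),
    input_pos=12,
)
print(y.shape)  # documented as (eff_batch_size, n_head, num, head_size)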
+ + """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ): + super().__init__( + config=config, + batch_size=batch_size, + cache_length=cache_length, + block_idx=block_idx, + device=device, + dtype=dtype, + head_size=head_size, + ) + self.mha = MultiHeadSelfAttention(config) + + @property + def eff_batch_size(self) -> Optional[int]: + raise NotImplementedError() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + input_pos: int, + ) -> torch.Tensor: + for_prefill = input_pos == 0 + if query.ndim != 4: + raise ValueError("query, key, value must be 4D tensors") + eff_batch_size, _, num, _ = query.shape + if for_prefill: + if not (1 <= eff_batch_size <= self.batch_size): + raise ValueError(f"query.shape[0] = {eff_batch_size}, must be in [1, {self.batch_size}]") + if self.max_prefill_length is not None and not (1 <= num <= self.max_prefill_length): + raise ValueError(f"query.shape[2] = {num}, must be in [1, {self.max_prefill_length}]") + else: + if eff_batch_size != self.eff_batch_size: + raise ValueError(f"query.shape[0] = {eff_batch_size} != eff_batch_size = {self.eff_batch_size}") + if not (1 <= num <= self.max_tokens_forward): + raise ValueError(f"query.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + q_shape = (eff_batch_size, self.n_head, num, self.head_size) + if query.shape != q_shape: + raise ValueError(f"query.shape = {query.shape}, must be {q_shape}") + k_shape = (eff_batch_size, self.n_query_groups, num, self.head_size) + if key.shape != k_shape: + raise ValueError(f"key.shape = {key.shape}, must be {k_shape}") + if value.shape != k_shape: + raise ValueError(f"value.shape = {value.shape}, must be {k_shape}") + t_shape = (eff_batch_size, num) + if token_idx.shape != t_shape: + raise ValueError(f"token_idx.shape = {token_idx.shape}, must be {t_shape}") + self.mha.set_seq_length(input_pos + num, device=query.device) + + # Call :meth:`_forward` or :meth:`_prefill`, depending on `for_prefill` + if for_prefill: + self._prefill(key, value, token_idx) + # In this case, `k_and_v` can vend both keys and values at the same + # time. + k_and_v = DefaultKeysAndValues(key, value) + else: + # Extend KV cache and retrieve key, value tensors to be used. + # Instead of asking for the key and value tensors as such, + # `k_and_v` allows access to them. Since they are never needed at + # the same time, this can save memory. + k_and_v = self._forward(key, value, token_idx) + + # Multi-head self-attention main computation + return_attn_weights = self.update_requires_attn_weights() + y, scores = self.mha( + query=query, + k_and_v=k_and_v, + block_idx=self.block_idx, + input_pos=input_pos, + return_attn_weights=return_attn_weights, + token_positions=self.token_positions(), + ) + if scores is not None and return_attn_weights: + # Pass attention weights to KV cache + self._update(attn_weights=scores) + + return y + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + """ + Implements part of :meth:`forward` if `input_pos > 0`. Namely, `key` + and `value` are written into the cache, possibly evicting slots. Then, + an object is returned which provides read access to the full keys and + values buffers. 
+ + Args: + key: New keys, `(eff_batch_size, n_query_groups, num, head_size)`, + where `1 <= num <= max_tokens_forward` + value: New values, `(eff_batch_size, n_query_groups, num, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, num)`. + Some KV caches make use of this information. + + Returns: + key_cached, value_cached, `(eff_batch_size, n_query_groups, T, + head_size)`, where `T <= cache_length` is the current cache + length + + """ + raise NotImplementedError() + + def _update(self, *args, **kwargs): + """ + Method called in :meth:`forward` if `input_pos > 0`, passing extra + information depending on the subclass. In general, this method updates + internal scores and takes a decision which slot is evicted upon the + next :meth:`forward` call, if the cache is full. + + One important example are KV caches based on the Heavy Hitter Oracle + (H2O) proposal. These require the attention weights from the current + MLA computation to be passed, and :meth:`update_requires_attn_weights` + has to return `True`. + + Note: The extra information typically scales with `num`, the number of + tokens :meth:`forward` was called for. + + Args: + *args: Depends on subclass + **kwargs: Depends on subclass + + """ + raise NotImplementedError() + + def update_requires_attn_weights(self) -> bool: + """ + Attention weights are required for KV caches following the Heavy + Hitter Oracle (H2O) proposal. + + Returns: + If `True`, :meth:`update` requires argument `attn_weights`, which + passes current attention weights as + `(eff_batch_size, n_query_groups, num, T)` tensor, where + `T <= cache_length` is the current cache length, and `num` is the + number of tokens in the last recent :meth:`forward` call. + + """ + return False + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + """ + Implements :meth:`forward` for `input_pos=0`. + Starts a generation loop by passing key and value tensors coming from + a prefill with embeddings coming from the prompts. The length must be + `T <= max_prefill_length`. The effective batch size must be + `eff_batch_size <= batch_size`. This batch size is then fixed for + subsequent calls of :meth:`forward` and :meth:`update`. + + Args: + key: Prefill keys, `(eff_batch_size, n_query_groups, T, head_size)` + value: Prefill values, `(eff_batch_size, n_query_groups, T, head_size)` + token_idx: Token indices of input sequence, `(eff_batch_size, T)`. + Some KV caches make use of this information. + + """ + raise NotImplementedError() diff --git a/litgpt/kvcache/baselines.py b/litgpt/kvcache/baselines.py new file mode 100644 index 0000000000..89affeaefd --- /dev/null +++ b/litgpt/kvcache/baselines.py @@ -0,0 +1,356 @@ +from typing import Dict, Optional, Tuple + +import torch + +from litgpt.attention import DefaultKeysAndValues, KeysAndValues +from litgpt.config import Config +from litgpt.kvcache import DefaultKVCache, KVCacheParams +from litgpt.kvcache.utils import bits_for_torch_dtype, bitsize_of + + +class DenseKVCache(DefaultKVCache): + """ + Key-value cache for dense attention. Key and value tensors for all + past tokens are maintained. The cache length is the maximum sequence + length. This cache requires a lot of memory, it can only be used for + moderate cache lengths. + + Note: If the cache is full, :meth:`forward` raises an exception. The cache + buffers are allocated up front and are not enlarged later on. 
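The memory warning in the docstring above can be quantified with `size_estimate_apriori` before allocating anything. A sketch with hypothetical 7B-scale geometry (32 layers, 32 heads of size 128, no GQA, bf16):

import torch

from litgpt.kvcache import DenseKVCache, KVCacheParams

params = KVCacheParams(
    batch_size=1,
    n_query_groups=32,
    cache_length=4096,
    head_size=128,
    n_head=32,
    device=None,
    dtype=torch.bfloat16,
)
bits_per_layer, parts = DenseKVCache.size_estimate_apriori(params)
n_layer = 32
print(f"{bits_per_layer * n_layer / 8 / 1e9:.2f} GB across {n_layer} layers")  # ~2.15 GB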
+ + """ + + def __init__( + self, + config: Config, + batch_size: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: Optional[int] = None, + head_size: Optional[int] = None, + ): + """ + Args: + config: Model config + batch_size: Inference batch size + device: Device for buffers + dtype: Data type for buffers + max_sequence_length: Cache length. If not given, we use + `config.block_size` + head_size: Size of final dimension of buffers. Defaults to head + size of model + + """ + if max_sequence_length is None: + max_sequence_length = config.block_size + super().__init__( + config=config, + batch_size=batch_size, + cache_length=max_sequence_length, + block_idx=block_idx, + device=device, + dtype=dtype, + head_size=head_size, + ) + shape = (batch_size, self.n_query_groups, max_sequence_length, self.head_size) + self.register_buffer("v", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + shape = shape[:-1] + (self.head_size + 1,) + self.register_buffer("k", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + self.next_position = None + self._eff_batch_size = None + + @property + def eff_batch_size(self) -> Optional[int]: + return self._eff_batch_size + + @property + def next_token_pos(self) -> Optional[int]: + return self.next_position + + @property + def max_tokens_forward(self) -> int: + return self.cache_length + + @property + def max_prefill_length(self) -> Optional[int]: + return self.cache_length + + @property + def current_length(self) -> int: + return self.next_position + + def get_keys_values(self) -> Optional[KeysAndValues]: + if self.eff_batch_size is None or self.next_position is None: + return None + else: + return DefaultKeysAndValues( + self.k[: self.eff_batch_size, :, : self.next_position, :], + self.v[: self.eff_batch_size, :, : self.next_position, :], + ) + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + if self.next_position is None: + raise IndexError("Cache needs to be initialized with 'prefill' before being used") + num = key.shape[2] + if not 1 <= num <= self.max_tokens_forward: + raise ValueError(f"key.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + np = self.next_position + if np + num > self.cache_length: + raise IndexError(f"Cache has at most {self.cache_length - np} free slots, cannot add {num} entries") + shape = (self.eff_batch_size, self.n_query_groups, num, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + if key.dtype != value.dtype: + raise ValueError(f"key.dtype = {key.dtype} != {value.dtype} = value.dtype") + # Move the buffer to the activation dtype for when AMP is used + # TODO: Is this needed? Other KV caches do not support changing + # `dtype` after creation. 
+ if key.dtype != self.dtype: + self._dtype = key.dtype + self.k = self.k.to(self.dtype) + self.v = self.v.to(self.dtype) + # Append new content to cache + self.k[: self.eff_batch_size, :, np : (np + num), :] = key + self.v[: self.eff_batch_size, :, np : (np + num), :] = value + self.next_position += num + return self.get_keys_values() + + def _update(self, *args, **kwargs): + pass + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + if key.dim() != 4: + raise ValueError("key must have 4 dimensions") + init_length = key.shape[2] + if init_length > self.cache_length: + raise ValueError(f"key.shape[2] = {init_length}, must be at most {self.cache_length}") + eff_batch_size = key.shape[0] + if eff_batch_size > self.batch_size: + raise ValueError(f"key.shape[0] = {eff_batch_size} must be at most batch_size = {self.batch_size}") + shape = (eff_batch_size, self.n_query_groups, init_length, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Initialize cache content + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + self.k[:eff_batch_size, :, :init_length, :] = key + self.v[:eff_batch_size, :, :init_length, :] = value + self.next_position = init_length + self._eff_batch_size = eff_batch_size + + def token_positions(self) -> torch.Tensor: + return ( + torch.arange(self.next_position, device=self.device) + .reshape(1, 1, -1) + .expand(self.eff_batch_size, self.n_query_groups, -1) + ) + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + sz_buffs = bitsize_of(self.k) + bitsize_of(self.v) + return sz_buffs, dict(buffers=sz_buffs) + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + cache_length = params.cache_length + dtype = params.dtype + if dtype is None: + raise ValueError("params.dtype must be provided") + numel = params.batch_size * params.n_query_groups * cache_length * params.head_size + sz_buffs = 2 * numel * bits_for_torch_dtype(dtype) + return sz_buffs, dict(buffers=sz_buffs) + + +class LastRecentlyInsertedKVCache(DefaultKVCache): + """ + Baseline key-value cache which stores the last recently inserted + `cache_length` key, value tensors. 
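`LastRecentlyInsertedKVCache` ignores attention weights, but it can serve as storage for a sketch of the scoring hooks described on `DefaultKVCache` above. The subclass below is purely illustrative (the name is made up, and it only records the latest weights; a real H2O-style cache would accumulate per-slot scores and evict accordingly):

import torch

from litgpt.kvcache import LastRecentlyInsertedKVCache


class AttnScoreLoggingKVCache(LastRecentlyInsertedKVCache):
    """Toy cache that asks `DefaultKVCache.forward` for attention weights."""

    def update_requires_attn_weights(self) -> bool:
        # Makes `DefaultKVCache.forward` run MHA with `return_attn_weights=True`
        # and call `_update(attn_weights=...)` afterwards.
        return True

    def _update(self, *, attn_weights: torch.Tensor) -> None:
        # Shape: (eff_batch_size, n_query_groups, num, T), T = current cache length.
        self.last_attn_weights = attn_weights.detach()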
+ """ + + def __init__( + self, + config: Config, + batch_size: int, + cache_length: int, + block_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + head_size: Optional[int] = None, + ): + super().__init__( + config=config, + batch_size=batch_size, + cache_length=cache_length, + block_idx=block_idx, + device=device, + dtype=dtype, + head_size=head_size, + ) + shape = (batch_size, self.n_query_groups, cache_length, self.head_size) + self.register_buffer("v", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + shape = shape[:-1] + (self.head_size + 1,) + self.register_buffer("k", torch.zeros(shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("token_pos", torch.zeros(cache_length, device=device, dtype=torch.int), persistent=False) + self.next_position = None + self._eff_batch_size = None + self.current_length = None + self._next_token_pos = None + + @property + def eff_batch_size(self) -> Optional[int]: + return self._eff_batch_size + + @property + def next_token_pos(self) -> Optional[int]: + return self._next_token_pos + + @property + def max_tokens_forward(self) -> int: + return self.cache_length + + @property + def max_prefill_length(self) -> Optional[int]: + return None + + def get_keys_values(self) -> Optional[KeysAndValues]: + if self.eff_batch_size is None or self.current_length is None: + return None + else: + return DefaultKeysAndValues( + self.k[: self.eff_batch_size, :, : self.current_length, :], + self.v[: self.eff_batch_size, :, : self.current_length, :], + ) + + def _forward( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ) -> KeysAndValues: + if self.next_position is None: + raise IndexError("Cache needs to be initialized with 'prefill' before being used") + if key.ndim != 4: + raise ValueError(f"key must be a 4D tensor, but has shape {key.shape}") + num = key.shape[2] + if not 1 <= num <= self.max_tokens_forward: + raise ValueError(f"key.shape[2] = {num}, must be in [1, {self.max_tokens_forward}]") + shape = (self.eff_batch_size, self.n_query_groups, num, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + # Append new content to cache + np = self.next_position + num1 = min(num, self.cache_length - np) + self.k[: self.eff_batch_size, :, np : (np + num1), :] = key[:, :, :num1, :] + self.v[: self.eff_batch_size, :, np : (np + num1), :] = value[:, :, :num1, :] + ntp = self._next_token_pos + self.token_pos[np : (np + num1)] = torch.arange(ntp, ntp + num1, device=self.device, dtype=torch.int) + if num1 < num: + diff = num - num1 + self.k[: self.eff_batch_size, :, :diff, :] = key[:, :, num1:, :] + self.v[: self.eff_batch_size, :, :diff, :] = value[:, :, num1:, :] + self.token_pos[:diff] = torch.arange(ntp + num1, ntp + num, device=self.device, dtype=torch.int) + self.next_position = (np + num) % self.cache_length + self.current_length = min(self.current_length + num, self.cache_length) + self._next_token_pos += num + return self.get_keys_values() + + def _update(self, 
*args, **kwargs): + pass + + def _prefill( + self, + key: torch.Tensor, + value: torch.Tensor, + token_idx: torch.Tensor, + ): + if key.dim() != 4: + raise ValueError("key must have 4 dimensions") + init_length = key.shape[2] + eff_init_length = min(init_length, self.cache_length) + eff_batch_size = key.shape[0] + if eff_batch_size > self.batch_size: + raise ValueError(f"key.shape[0] = {eff_batch_size} must be at most batch_size = {self.batch_size}") + shape = (eff_batch_size, self.n_query_groups, init_length, self.head_size) + # TODO: Remove once HF bug fixed + if self._work_around_hf_bug: + assert value.shape == shape + shape = shape[:-1] + (self.head_size + 1,) + assert key.shape == shape + elif key.shape != shape or value.shape != shape: + raise ValueError( + f"Shapes of key, value must be {shape}, but key.shape = {key.shape}, value.shape = {value.shape}" + ) + # Initialize cache content + self.k = self.k.to(key.dtype) + self.v = self.v.to(value.dtype) + self.k[:eff_batch_size, :, :eff_init_length, :] = key[:, :, -eff_init_length:, :] + self.v[:eff_batch_size, :, :eff_init_length, :] = value[:, :, -eff_init_length:, :] + self.token_pos[:eff_init_length] = torch.arange( + init_length - eff_init_length, + init_length, + dtype=self.token_pos.dtype, + device=self.token_pos.device, + ) + self.current_length = eff_init_length + self._next_token_pos = init_length + self.next_position = eff_init_length % self.cache_length + self._eff_batch_size = eff_batch_size + + def token_positions(self) -> torch.Tensor: + return ( + self.token_pos[: self.current_length].reshape(1, 1, -1).expand(self.eff_batch_size, self.n_query_groups, -1) + ) + + def size_estimate(self) -> Tuple[int, Dict[str, int]]: + sz_buffs = bitsize_of(self.k) + bitsize_of(self.v) + sz_pos = bitsize_of(self.token_pos) + return sz_buffs + sz_pos, dict(buffers=sz_buffs, token_pos=sz_pos) + + @classmethod + def size_estimate_apriori(cls, params: KVCacheParams, **kwargs) -> Tuple[int, Dict[str, int]]: + cache_length = params.cache_length + dtype = params.dtype + if dtype is None: + raise ValueError("params.dtype must be provided") + numel = params.batch_size * params.n_query_groups * cache_length * params.head_size + k_and_v = 2 * numel * bits_for_torch_dtype(dtype) + tk_p = cache_length * bits_for_torch_dtype(torch.int) + return k_and_v + tk_p, dict(buffers=k_and_v, token_pos=tk_p) diff --git a/litgpt/kvcache/test_utils.py b/litgpt/kvcache/test_utils.py new file mode 100644 index 0000000000..9dc0989ec3 --- /dev/null +++ b/litgpt/kvcache/test_utils.py @@ -0,0 +1,77 @@ +from typing import Tuple + +import torch + +from litgpt.config import Config +from litgpt.kvcache.base import KVCache, KVCacheParams +from litgpt.kvcache.baselines import DenseKVCache, LastRecentlyInsertedKVCache + +KV_CACHE_NAMES = ( + "dense-default", + "lastrec-default", +) + + +def create_kv_cache( + name: str, + params: KVCacheParams, + block_idx: int = 0, +) -> KVCache: + config = Config( + n_embd=params.n_head * params.head_size, + n_head=params.n_head, + n_query_groups=params.n_query_groups, + ) + from_config_kwargs = dict( + config=config, + batch_size=params.batch_size, + block_idx=block_idx, + device=params.device, + dtype=params.dtype, + ) + + result = None + if name == "dense-default": + result = DenseKVCache(**from_config_kwargs) + elif name == "lastrec-default": + result = LastRecentlyInsertedKVCache(**from_config_kwargs, cache_length=params.cache_length) + + if result is None: + raise ValueError(f"name = {name} not supported") + return result + + +def 
tensor_is_simple(x: torch.Tensor) -> bool: + assert x.ndim > 1 + x = x.view(-1, x.shape[-1]) + other = x[0].unsqueeze(0).expand(*x.shape) + return x.equal(other) + + +def random_tensor( + params: KVCacheParams, + num: int, +) -> torch.Tensor: + shape = (params.batch_size, params.n_query_groups, num, params.head_size) + return torch.randn(*shape, device=params.device, dtype=params.dtype) + + +def random_keys_values( + params: KVCacheParams, + num: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + keys = random_tensor(params, num) + values = random_tensor(params, num) + return keys, values + + +def random_attn_weights( + params: KVCacheParams, + num: int, +) -> torch.Tensor: + attn_weights = torch.randn( + (params.batch_size, params.n_head, num), + device=params.device, + dtype=params.dtype, + ) + return torch.nn.functional.softmax(attn_weights, dim=-1) diff --git a/litgpt/kvcache/utils.py b/litgpt/kvcache/utils.py new file mode 100644 index 0000000000..d618cd8202 --- /dev/null +++ b/litgpt/kvcache/utils.py @@ -0,0 +1,17 @@ +import torch + + +def bits_for_torch_dtype(dtype: torch.dtype) -> int: + """ + Args: + dtype: Torch data type + + Returns: + Number of bits used to represent one number of this type. + + """ + return torch.tensor([], dtype=dtype).element_size() * 8 + + +def bitsize_of(x: torch.Tensor) -> int: + return x.numel() * x.element_size() * 8 diff --git a/litgpt/lora.py b/litgpt/lora.py index 0277412c67..1c1c53dc0a 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -53,7 +53,9 @@ from typing_extensions import Self import litgpt +from litgpt.attention import MultiHeadSelfAttention from litgpt.config import Config as BaseConfig +from litgpt.kvcache.base import KVCache from litgpt.model import GPT as BaseModel from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention @@ -496,8 +498,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None + self.mha = MultiHeadSelfAttention(config) self.max_seq_length = self.config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_cache`? 
+ self._default_kv_cache = False @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -517,15 +522,25 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) - self.attn = CausalSelfAttention(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.mlp = config.mlp_class(config) class CausalSelfAttention(BaseCausalSelfAttention): - def __init__(self, config: Config, block_idx: int) -> None: - super().__init__(config, block_idx) + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: + super().__init__(config, block_idx, kv_cache) # key, query, value projections for all heads, but in a batch shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = LoRAQKVLinear( diff --git a/litgpt/model.py b/litgpt/model.py index 24d952340e..197f781745 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -6,7 +6,7 @@ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. """ -import math +from dataclasses import replace from functools import partial from typing import Any, List, Optional, Tuple, Union @@ -15,17 +15,37 @@ import torch.nn.functional as F from typing_extensions import Self -from litgpt.config import Config +from litgpt.attention import ( + DefaultKeysAndValues, + MultiHeadSelfAttention, + do_softcapping, +) +from litgpt.config import Config, StartOfLayerHook +from litgpt.kvcache import ( + DenseKVCache, + KVCache, + KVCacheParams, +) from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble +from litgpt.utils import batched_index_select class GPT(nn.Module): def __init__(self, config: Config) -> None: + """ + Args: + config: Configuration parameters + + """ super().__init__() assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), @@ -33,8 +53,11 @@ def __init__(self, config: Config) -> None: ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.mask_cache: Optional[torch.Tensor] = None - self.max_seq_length = self.config.block_size + self.mha = MultiHeadSelfAttention(config) + self.max_seq_length = config.block_size + self._start_of_layer_hook = config.start_of_layer_hook + # Have dense KV caches been created by `set_kv_cache`? + self._default_kv_cache = False @property def max_seq_length(self) -> int: @@ -43,8 +66,13 @@ def max_seq_length(self) -> int: @max_seq_length.setter def max_seq_length(self, value: int) -> None: """ - When doing inference, the sequences used might be shorter than the model's context length. - This allows setting a smaller number to avoid allocating unused memory + When doing inference, the sequences used might be shorter than the + model's context length. This allows setting a smaller number to avoid + allocating unused memory. + + If KV caches are of type `DenseKVCache`, and they are too small to hold + `value` entries, a warning message is printed. 
+ """ if value > self.config.block_size: raise ValueError( @@ -52,24 +80,159 @@ def max_seq_length(self, value: int) -> None: " This is likely because the input text exceeds the supported context length of this model." ) self._max_seq_length = value + # RoPE cache: + # `cos`, `sin` of shape `(max_seq_length, config.rope_n_elem)` + # More precisely, the RoPE cache is recomputed only if + # `max_seq_length` increases. + # Note: The RoPE cache is independent of KV caches, since positional + # encoding is done (on query and key vectors) before the KV cache + # gets involved (and the KV cache stores encoded key tensors). if not hasattr(self, "cos"): # first call cos, sin = self.rope_cache() self.register_buffer("cos", cos, persistent=False) self.register_buffer("sin", sin, persistent=False) - # override - elif value != self.cos.size(0): + elif self.cos.size(0) < value: self.cos, self.sin = self.rope_cache(device=self.cos.device) - # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know - # if the kv cache is expected - if self.mask_cache is not None and self.mask_cache.shape[-1] < value: - print( - f"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!" - ) + # KV caches + # We do not change them here, but output a warning if default caches are + # too small + for l_ix, block in enumerate(self.transformer.h): + attn = block.attn + kv_cache = attn.kv_cache + if kv_cache is not None and isinstance(kv_cache, DenseKVCache) and kv_cache.cache_length < value: + print( + f"KV cache for layer {l_ix} too small: Call 'set_kv_cache(batch_size={kv_cache.batch_size}, max_seq_length={value}) before inference" + ) + break + # Multi-head attention + self.mha.set_seq_length(value, device=self.cos.device) + + def are_kv_caches_assigned(self) -> bool: + status = [block.attn.kv_cache is not None for block in self.transformer.h] + result = any(status) + if result and not all(status): + raise IndexError("Some layers have KV caches assigned, but not all") + return result + + def assign_kv_caches(self, kv_caches: List[KVCache]): + """ + Assigns specific KV caches to the multi-head attention blocks + of each layer. This can only be done if no caches have been + assigned or created (see :meth:`set_kv_cache`) before. + + KV caches are required for inference (i.e., calling :meth:`forward` with + `input_pos` argument). If no KV caches are assigned, inference calls + fail. + + Args: + kv_caches: KV caches, one for each layer of the model + + """ + if self.are_kv_caches_assigned(): + raise ValueError("Model has KV caches assigned already") + if len(kv_caches) != self.config.n_layer: + raise ValueError(f"kv_caches must have one entry per layer, so {self.config.n_layer} entries ") + for cache, block in zip(kv_caches, self.transformer.h): + self._check_kv_cache(self.config, cache) + block.attn.kv_cache = cache + + def set_kv_cache( + self, + batch_size: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_seq_length: Optional[int] = None, + ): + """ + This method can be called only if KV caches have not been assigned + with :meth:`assign_kv_caches`. It creates default (dense) KV caches + for every layer. These may require a lot of memory. If this is an + issue, consider :meth:`assign_kv_caches` with KV caches of restricted + size. + + KV caches are required for inference (i.e., calling :meth:`forward` with + `input_pos` argument). 
If no KV caches are assigned, inference calls + fail. + + Args: + batch_size: Inference batch size + device: Device for buffers + dtype: Data type for buffers + max_seq_length: Cache length. If not given, we use + `self.max_seq_length` + + """ + if self.are_kv_caches_assigned() and not self._default_kv_cache: + raise ValueError("Model has KV caches assigned already") + if max_seq_length is None: + max_seq_length = self.max_seq_length + for block in self.transformer.h: + attn = block.attn + kv_cache = attn.kv_cache + if ( + kv_cache is None + or kv_cache.batch_size != batch_size + or kv_cache.cache_length != max_seq_length + or kv_cache.device != device + or kv_cache.dtype != dtype + ): + if kv_cache is not None: + device = kv_cache.device if device is None else device + dtype = kv_cache.dtype if dtype is None else dtype + attn.create_default_kv_cache( + batch_size=batch_size, + device=device, + dtype=dtype, + max_sequence_length=max_seq_length, + ) + self._default_kv_cache = True def reset_parameters(self) -> None: # Trigger resetting the rope-cache self.cos, self.sin = self.rope_cache(device=self.cos.device) + self.mha.set_seq_length(self.max_seq_length, device=self.cos.device) + + def set_start_of_layer_hook( + self, + hook: Optional[StartOfLayerHook], + ): + """ + Sets a function `hook(x, block_idx, input_pos)`, which is called + in :meth:`forward` at the start of each layer. Here, `x` is the + layer input, `block_idx` the number of the layer, and `input_pos` + the position in the sequence. The hook is called with the output + of the final layer (input of head model), where + `block_idx=self.config.n_layer`. + + The default start of layer hook is `self.config.start_of_layer_hook`. + This is overwritten here. + + Args: + hook: Hook function to be set, or `None` to remove hook + + """ + self._start_of_layer_hook = hook + + @staticmethod + def _check_kv_cache( + config: Config, + kv_cache: KVCache, + ): + params = kv_cache.get_params() + if config.n_query_groups != params.n_query_groups: + raise ValueError( + f"config and kv_cache not compatible: config.n_query_groups = {config.n_query_groups} != {params.n_query_groups} = kv_cache.n_query_groups" + ) + if config.n_head != params.n_head: + raise ValueError( + f"config and kv_cache not compatible: config.n_head = {config.n_head} != {params.n_head} = kv_cache.n_head" + ) + head_size = config.n_embd // config.n_head + if head_size != params.head_size: + raise ValueError( + f"config and kv_cache not compatible: config.head_size = {head_size} != {params.head_size} = kv_cache.head_size" + ) def _init_weights(self, module: nn.Module) -> None: """Meant to be used with `gpt.apply(gpt._init_weights)`.""" @@ -83,95 +246,133 @@ def _init_weights(self, module: nn.Module) -> None: def forward( self, idx: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + input_pos: Optional[int] = None, lm_head_chunk_size: int = 0, + skip_lm_head: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """ - If `input_pos` is provided, the KV cache uses K and V vectors for - positions smaller than entries in `input_pos`. For efficiency, pass - `input_pos_maxp1` as `max(input_pos) + 1` if already available from - your forward algorithm. This slices the KV cache buffers and speeds - up multi-head attention. - - Without `input_pos_maxp1`, the computation uses the full KV cache - (`max_seq_length`) with masking applied. 
Note that inferring - `input_pos_maxp1` from `input_pos` causes graph breaks and prevents - compilation. + There are two different contexts in which this method is called: + + - Training: `input_pos` not given. KV cache is not needed. + - Inference, `input_pos` is given. There are two cases: `input_pos=0` + (prefill) and `input_pos > 0` (generation). For prefill, KV caches + must have been assigned (:meth:`assign_kv_caches` or + :meth:`set_kv_cache`). We must have + `T <= model.kv_cache_max_prefill_length()`. + - For generation, KV caches must have been assigned + (:meth:`assign_kv_caches` or :meth:`set_kv_cache`). We check that + `input_pos == kv_cache.next_token_pos`. Note that `T > 1` is + permitted here as well. + + Note: If this method is called with `input_pos=0` (prefill) after + generation calls, a new inference sequence is started. The batch + size for the new sequence can be different. + + Token generation (`input_pos > 0`) and `T > 1`: + + This situation is non-standard, since `idx` needs to provide tokens at + positions `input_pos:(input_pos + T)`, whereas the logits are for + generating tokens at `(input_pos + 1):(input_pos + T + 1)`, so only the + last position is needed to generate a new token. Use cases: + - Updating KV caches sequentially if prompt size is larger than max + prefill length of cache + - Speculative decoding. Here, `idx` comes from the cheaper proposal + model, and the logits are needed for the accept/reject probabilities. Args: idx: Token indices of input sequences, shape `(B, T)`, where `B` is batch size. - input_pos: Optional. Positions of input tokens. The default is - `arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index). - input_pos_maxp1: Optional. See above. + input_pos: See above. Defaults to `None` lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final `lm_head` computation is done in chunks of this size. + skip_lm_head: If `True`, we do not apply the final LM head + `self.lm_head`. Returns: Logit outputs, shape `(B, T, config.padded_vocab_size)`. If `lm_head_chunk_size > 0`, this is a list of chunks of shape `(B, lm_head_chunk_size, config.padded_vocab_size)`, the final entry can be shorter. + If `skip_lm_head` is `True`, we return the final layer outputs, + shape `(B, T, config.n_embd)`. 
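Against this contract, a full inference round trip might look as follows. This is a sketch: the model name is a placeholder for any supported config, weights are random (no checkpoint is loaded), and the fixed cache length is arbitrary.

import torch

from litgpt.kvcache import LastRecentlyInsertedKVCache
from litgpt.model import GPT

model = GPT.from_name("pythia-160m")   # placeholder model name
model.eval()

# One restricted cache per layer instead of the default dense caches.
caches = [
    LastRecentlyInsertedKVCache(model.config, batch_size=1, cache_length=256, block_idx=i)
    for i in range(model.config.n_layer)
]
model.assign_kv_caches(caches)

prompt = torch.randint(0, model.config.padded_vocab_size, (1, 32))
with torch.no_grad():
    logits = model(prompt, input_pos=0)           # prefill
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = model(next_token, input_pos=32)      # generate the next position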
""" + if idx.ndim == 1: + idx = idx.unsqueeze(0) + elif idx.ndim != 2: + raise ValueError(f"idx must be 1D or 2D tensor, but idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") - - if input_pos is not None: # use the kv cache - if input_pos.dim() > 2: - # otherwise, things go wrong in `apply_rope` - raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}") - if input_pos.shape[-1] != T: - raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") - cos = batched_index_select(self.cos, 0, input_pos) - sin = batched_index_select(self.sin, 0, input_pos) - if input_pos.dim() == 1: - cos = cos.unsqueeze(0) - sin = sin.unsqueeze(0) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = batched_index_select(self.mask_cache, 2, input_pos) - if mask.dim() > 4: - # the mask cache has a batch dim of 1 in addition to the one - # we get if input_pos has a batch dimension - mask = mask.view(*(mask.shape[0:1] + mask.shape[2:])) - if input_pos_maxp1 is not None: - # Shorten final dimension so it just covers all `input_pos` entries - if input_pos_maxp1 > self.max_seq_length: - raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})") - mask = mask[..., :input_pos_maxp1] + for_prefill = False + if input_pos is not None: + # Few tokens generation. This needs a KV cache. If none is assigned, + # the call fails + if not self.are_kv_caches_assigned(): + raise ValueError( + "KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'" + ) + for_prefill = input_pos == 0 + if not for_prefill: + for block_idx, block in enumerate(self.transformer.h): + kv_cache = block.attn.kv_cache + if kv_cache.next_token_pos is None: + raise ValueError("Inference calls need to start with pre-fill, i.e. 
'input_pos=0'") + if kv_cache.next_token_pos != input_pos: + raise ValueError( + f"KV cache for layer {block_idx}: input_pos = {input_pos} != {kv_cache.next_token_pos} = kv_cache.next_token_pos" + ) + if kv_cache.max_tokens_forward < T: + raise ValueError( + f"KV cache for layer {block_idx}: T = {T}, must be <= max_tokens_forward = {kv_cache.max_tokens_forward}" + ) + + if self.config.rope_n_elem > 0: + input_pos_array = torch.arange(input_pos, input_pos + T, device=self.cos.device, dtype=torch.int64) + cos = batched_index_select(self.cos, 0, input_pos_array).unsqueeze(0) + sin = batched_index_select(self.sin, 0, input_pos_array).unsqueeze(0) + else: + cos = sin = None else: - # unsqueeze to have a batch dimension + # Unsqueeze to have a batch dimension cos = self.cos[:T].unsqueeze(0) sin = self.sin[:T].unsqueeze(0) - # `cos`, `sin` have shape (1, T, config.rope_n_elem) - mask = None # defaults to causal mask - input_pos_maxp1 = None + # `cos`, `sin` have shape `(1, T, config.rope_n_elem)`, or shape + # `(1, T, config.rope_n_elem, 2)` x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) + hook = self._start_of_layer_hook for block_idx, block in enumerate(self.transformer.h): + if for_prefill: + # Complain if batch size of cache is too small + eff_batch_size = x.shape[0] + attn = block.attn + if attn.kv_cache.batch_size < eff_batch_size: + raise ValueError( + f"Batch size {eff_batch_size} is too large for KV cache layer {block_idx} (batch size {attn.kv_cache.batch_size}). Use 'assign_kv_caches' or `set_kv_cache'" + ) + if hook is not None: + # Call start of layer hook, passing detached layer input + hook(x.detach(), block_idx, input_pos) if self.config.rope_indices is not None: - x = block( - x, - cos[..., self.config.rope_indices[block_idx]], - sin[..., self.config.rope_indices[block_idx]], - mask, - input_pos, - input_pos_maxp1, - ) + # Select global (0) or local (1) variant + _cos = cos[..., self.config.rope_indices[block_idx]] + _sin = sin[..., self.config.rope_indices[block_idx]] else: - x = block(x, cos, sin, mask, input_pos, input_pos_maxp1) + _cos = cos + _sin = sin + x = block(x, _cos, _sin, idx, self.mha, input_pos) + + if hook is not None: + # Hook is also called for the input to the head block + hook(x.detach(), self.config.n_layer, input_pos) x = self.transformer.ln_f(x) - clamp_head = ( - partial(do_softcapping, thresh=self.config.final_logit_softcapping) - if self.config.final_logit_softcapping is not None - else nn.Identity() - ) + if skip_lm_head: + return x + clamp_head = partial(do_softcapping, thresh=self.config.final_logit_softcapping) if lm_head_chunk_size > 0: # chunk the lm head logits to reduce the peak memory used by autograd return [clamp_head(self.lm_head(x_i)) for x_i in x.split(lm_head_chunk_size, dim=1)] @@ -182,10 +383,23 @@ def forward( def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) - def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def rope_cache( + self, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Recomputes the RoPE cache, consisting of tensors `cos`, `sin`. + + Args: + device: Device for RoPE cache tensors + + Returns: + `(cos, sin)`, each of shape `(max_seq_length, config.rope_n_elem)` + or of shape `(max_seq_length, config.rope_n_elem, 2)`. 
+ + """ if self.config.rope_adjustments is None: extra_config = None - else: adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"] params_present = [param in self.config.rope_adjustments for param in adjusted_params_required] @@ -220,42 +434,59 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso rope_local_base_freq=self.config.rope_local_base_freq, ) - def set_kv_cache( - self, - batch_size: int, - max_seq_length: Optional[int] = None, - rope_cache_length: Optional[int] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - if rope_cache_length is None: - if len(self.cos.shape) == 2: - rope_cache_length = self.cos.size(-1) - else: - rope_cache_length = self.cos[..., 0].size(-1) + def clear_kv_cache(self) -> None: + """ + Note that KV cache objects are removed only if they have not been + assigned with :meth:`assign_kv_caches`. - if max_seq_length is None: - max_seq_length = self.max_seq_length + """ + if self._default_kv_cache: + for block in self.transformer.h: + block.attn.kv_cache = None + self._default_kv_cache = False - # initialize the kv cache for all blocks - for block in self.transformer.h: - block.attn.kv_cache = block.attn.build_kv_cache( - batch_size, - max_seq_length, - rope_cache_length, - device, - dtype, - ) + def get_kv_cache_params(self) -> Optional[KVCacheParams]: + """ + Assigned KV caches need not have the same `batch_size` and + `cache_length` values. For these, the returned :class:`KVCacheParams` + object contains the minimum values over all caches. - if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: - # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask - # for the kv-cache support (only during inference), we only create it in that situation - self.mask_cache = build_mask_cache(max_seq_length, device) + Returns: + Parameters for KV caches (see above), or `None` if KV caches are + not assigned. 
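A short sketch of how these accessors can be used to plan a run (assuming `model` is a `GPT` instance; `kv_cache_max_prefill_length` and `kv_cache_max_tokens_forward` are defined right below):

params = model.get_kv_cache_params()
if params is None:
    # No caches assigned yet: fall back to default dense caches.
    model.set_kv_cache(batch_size=1)
    params = model.get_kv_cache_params()
print(params.batch_size, params.cache_length)       # minima over all layers

max_prefill = model.kv_cache_max_prefill_length()   # None means "no limit"
chunk_limit = model.kv_cache_max_tokens_forward()

A caller can compare `max_prefill` and `chunk_limit` against the prompt length to decide whether the prompt has to be split, as in the `prompt_chunksize` handling earlier in this patch.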
- def clear_kv_cache(self) -> None: - self.mask_cache = None - for block in self.transformer.h: - block.attn.kv_cache = None + """ + kv_cache = self.transformer.h[0].attn.kv_cache + if kv_cache is None: + return None + # Caches need not have the same cache_length, batch_size + caches = [block.attn.kv_cache for block in self.transformer.h] + batch_size = min(c.batch_size for c in caches) + cache_length = min(c.cache_length for c in caches) + params = replace( + kv_cache.get_params(), + batch_size=batch_size, + cache_length=cache_length, + ) + return params + + def kv_cache_max_tokens_forward(self) -> Optional[int]: + caches = [layer.attn.kv_cache for layer in self.transformer.h] + if any(cache is None for cache in caches): + return None + else: + return min(cache.max_tokens_forward for cache in caches) + + def kv_cache_max_prefill_length(self) -> Optional[int]: + caches = [layer.attn.kv_cache for layer in self.transformer.h] + if any(cache is None for cache in caches): + return None + else: + mlps = [kvc.max_prefill_length for kvc in caches] + if all(mlp is None for mlp in mlps): + return None + else: + return min(mlp for mlp in mlps if mlp is not None) class Block(nn.Module): @@ -263,6 +494,7 @@ def __init__( self, config: Config, block_idx: int, + kv_cache: Optional[KVCache] = None, ) -> None: super().__init__() if not config.parallel_residual and config.shared_attention_norm: @@ -272,7 +504,7 @@ def __init__( ) self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) - self.attn = CausalSelfAttention(config, block_idx) + self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache) self.post_attention_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() ) @@ -281,7 +513,6 @@ def __init__( self.post_mlp_norm = ( config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity() ) - self.config = config def forward( @@ -289,9 +520,9 @@ def forward( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + token_idx: torch.Tensor, + mha: MultiHeadSelfAttention, + input_pos: Optional[int] = None, ) -> torch.Tensor: """ Non-parallel residual Parallel residual @@ -315,7 +546,14 @@ def forward( """ x_normed = self.norm_1(x) - attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1) + attention_output = self.attn( + x_normed, + cos=cos, + sin=sin, + token_idx=token_idx, + mha=mha, + input_pos=input_pos, + ) attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: @@ -329,7 +567,12 @@ def forward( class CausalSelfAttention(nn.Module): - def __init__(self, config: Config, block_idx: int) -> None: + def __init__( + self, + config: Config, + block_idx: int, + kv_cache: Optional[KVCache] = None, + ) -> None: super().__init__() # key, query and value projections for all heads, but in a batch self.qkv = nn.Linear( @@ -339,11 +582,8 @@ def __init__(self, config: Config, block_idx: int) -> None: ) # output projection self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) - # disabled by default - self.kv_cache: Optional[KVCache] = None - self.apply_sliding_window_attention = False - if config.sliding_window_size is not None and config.sliding_window_indices is not None: - self.apply_sliding_window_attention = config.sliding_window_indices[block_idx] + # KV cache (needed for 
inference) + self.kv_cache = kv_cache if config.norm_qk: self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps) @@ -359,10 +599,22 @@ def forward( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - input_pos_maxp1: Optional[int] = None, + token_idx: torch.Tensor, + mha: MultiHeadSelfAttention, + input_pos: Optional[int] = None, ) -> torch.Tensor: + """ + Args: + x: Input tensor + cos: RoPE parameters + sin: RoPE parameters + token_idx: Token indexes corresponding to `x` + mha: Multi-head self-attention code + input_pos: See :meth:`GPT.forward` + + Returns: + Output tensor + """ # Notation: # - B | batch size # - T | time-step (sequence length) @@ -375,6 +627,24 @@ def forward( n_query_groups = self.config.n_query_groups rope_n_elem = self.config.rope_n_elem B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + for_prefill = False + if input_pos is not None: + for_prefill = input_pos == 0 + if self.kv_cache is None: + raise ValueError( + "KV caches are not assigned. Assign KV caches with 'assign_kv_caches' or create default caches with 'set_kv_cache'" + ) + if not for_prefill: + if self.kv_cache.next_token_pos is None: + raise ValueError("Inference calls need to start with pre-fill, i.e. 'input_pos=0'") + if self.kv_cache.next_token_pos != input_pos: + raise ValueError( + f"KV cache: input_pos = {input_pos} != {self.kv_cache.next_token_pos} = kv_cache.next_token_pos" + ) + if self.kv_cache.max_tokens_forward < T: + raise ValueError( + f"KV cache: T = {T}, must be <= max_tokens_forward = {self.kv_cache.max_tokens_forward}" + ) # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` # instead of individually multiplying the input `x` with the respective weight matrices. @@ -391,117 +661,74 @@ def forward( # embedding size (C) into num_heads (nh) and head_size (hs). q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs) k = k.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) - v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_v, hs) + v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are - # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # multiple heads (nh_q), and within each head, there is a sequence of elements (T), each represented by a vector # of size `hs`. + # Note that `nh_k` can be smaller than `nh_q` (but the latter must be a + # multiple of the former). This works with the + # `scaled_dot_product_attention` implementations below. q = q.transpose(1, 2) # (B, nh_q, T, hs) k = k.transpose(1, 2) # (B, nh_k, T, hs) - v = v.transpose(1, 2) # (B, nh_v, T, hs) + v = v.transpose(1, 2) # (B, nh_k, T, hs) if self.config.norm_qk: q = self.norm_q(q) k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. 
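[Review note] The hunk below guards RoPE on `rope_n_elem > 0` and rotates only the leading channels of each head. A minimal, self-contained sketch of that partial-rotary pattern follows; the helper name is hypothetical, and the batch-dimension broadcasting and dtype handling of the real `apply_rope` in litgpt/model.py are omitted here.

import torch

def apply_partial_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    # Rotate only the first `rope_n_elem` channels of each head; the remaining
    # `head_size - rope_n_elem` channels are passed through unchanged.
    if rope_n_elem == 0:
        return q  # rotary_percentage == 0: nothing to rotate
    x = q[..., :rope_n_elem]
    half = rope_n_elem // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)  # rotate-half
    x_roped = x * cos + rotated * sin
    return torch.cat((x_roped, q[..., rope_n_elem:]), dim=-1)

# Tiny usage example with fused cos/sin of shape (T, rope_n_elem):
B, n_head, T, head_size, rope_n_elem = 1, 2, 5, 8, 4
q = torch.randn(B, n_head, T, head_size)
theta = 1.0 / (10000 ** (torch.arange(0, rope_n_elem, 2).float() / rope_n_elem))
angles = torch.outer(torch.arange(T).float(), theta).repeat(1, 2)  # (T, rope_n_elem)
out = apply_partial_rope(q, angles.cos(), angles.sin(), rope_n_elem)
assert out.shape == q.shape

The channels beyond `rope_n_elem` bypass the rotation entirely, which is what `rotary_percentage < 1.0` configures.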
- q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) - k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) + if rope_n_elem > 0: + q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) + + # Inner part of multi-head self-attention computation + if input_pos is None: + # Default causal self-attention + y, _ = mha( + query=q, + k_and_v=DefaultKeysAndValues(k, v), + block_idx=self.block_idx, + ) + else: + # Defer this to KV cache + y = self.kv_cache( + query=q, + key=k, + value=v, + token_idx=token_idx, + input_pos=input_pos, + ) - # Apply kv-cache during inference. - if input_pos is not None: - if not isinstance(self.kv_cache, KVCache): - raise TypeError("You need to call `gpt.set_kv_cache()`") - k, v = self.kv_cache(input_pos, k, v) - if input_pos_maxp1 is not None: - # Subselect along sequence dimension - k = k[..., :input_pos_maxp1, :] - v = v[..., :input_pos_maxp1, :] - # k, v: (B, nh_k, input_pos_maxp1, hs) - # If input_pos_maxp1 is None -> max_seq_length - - # Grouped queries: balance the number of heads across all three matrices. - # NOTE: flash attention requires it in training mode. - # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. - if n_query_groups != n_head and (input_pos is None or n_query_groups != 1): - q_per_kv = n_head // n_query_groups - k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) - v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) - - if self.apply_sliding_window_attention: - """ - Global Window Sliding window Sliding window - attention mask + bias = attention mask - ┌────────────────────────┐ ┌───────────────────────┐ ┌─────────────────────────┐ - │ True False False False │ │ True True True True │ │ True False False False │ - │ True True False False │ │ True True True True │ │ True True False False │ - │ True True True False │ │ False True True True │ │ False True True False │ - │ True True True True │ │ False False True True │ │ False False True True │ - └────────────────────────┘ └───────────────────────┘ └─────────────────────────┘ - """ - if mask is None: - mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1) - mask.masked_fill_(mask.bool(), float("-inf")) - mask = mask.view(1, 1, *mask.shape) - sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size) - sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf")) - mask += sliding_window_bias - - # Efficient attention using Flash Attention CUDA kernels. - # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. - # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) - y = self.scaled_dot_product_attention(q, k, v, mask) - - # Re-assemble all head outputs side by side. - y = y.reshape(B, T, head_size * n_head) - - # Output projection. 
+ # Output projection + y = self._transform_output(y, query=q, mha=mha) return self.proj(y) # (B, T, C) - def scaled_dot_product_attention( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None + def _transform_output( + self, + y: torch.Tensor, + query: torch.Tensor, + mha: MultiHeadSelfAttention, ) -> torch.Tensor: - scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) - - # with softcapping we cannot use SDPA - if self.config.attention_logit_softcapping is not None: - scores = q @ k.mT * scale - scores = do_softcapping(scores, self.config.attention_logit_softcapping) - if mask is None: - mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) - mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) - scores = scores + mask - scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) - y = scores @ v - else: - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None - ) - return y.transpose(1, 2) + return y - def build_kv_cache( + def create_default_kv_cache( self, batch_size: int, - max_seq_length: int, - rope_cache_length: Optional[int] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ) -> "KVCache": - v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size) - if rope_cache_length is None: - if self.config.rotary_percentage != 1.0: - raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") - k_shape = v_shape - else: - k_shape = ( - batch_size, - self.config.n_query_groups, - max_seq_length, - rope_cache_length + self.config.head_size - self.config.rope_n_elem, - ) - return KVCache(k_shape, v_shape, device=device, dtype=dtype) + max_sequence_length: Optional[int] = None, + ): + self.kv_cache = DenseKVCache( + config=self.config, + batch_size=batch_size, + block_idx=self.block_idx, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with legacy checkpoints.""" @@ -599,11 +826,17 @@ def build_rope_cache( device (torch.device, optional): Device for tensor allocations. base (int, optional): Base for computing inverse frequencies. condense_ratio (int, optional): Ratio to condense the position indices. - extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2) + extra_config (dict, optional): Configuration parameters for + frequency adjustments (used by Llama 3.1 and 3.2) + rope_local_base_freq: If given, this is an alternative value for + `base`. In this case, the returned tensors have an extra dimension. Returns: Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. - Shapes are `(seq_len, n_elem)`. + Shapes are `(seq_len, n_elem)` if `rope_local_base_freq` is not + given, otherwise `(seq_len, n_elem, 2)`, so that `[..., 0]` is for + `base`, and `[..., 1]` for `rope_local_base_freq`. + """ # Compute the inverse frequencies theta @@ -634,11 +867,13 @@ def build_rope_cache( idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) # If `n_elem` is odd, the final dimension of `idx_theta` has size # `n_elem + 1`, so need to cut something off. + # Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave # `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to # broadcasting. 
If we shorten `idx_theta`, unit tests comparing to # Hugging Face fail. # https://github.com/huggingface/transformers/issues/35233 + # TODO: Remove `> 1` once HF bug is fixed! if idx_theta.shape[-1] > n_elem > 1: idx_theta = idx_theta[..., :n_elem] @@ -648,6 +883,7 @@ def build_rope_cache( local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) local_idx_theta = torch.outer(seq_idx, local_theta) local_idx_theta = local_idx_theta.repeat(1, 2) + # TODO: Remove `> 1` once HF bug is fixed! if local_idx_theta.shape[-1] > n_elem > 1: local_idx_theta = local_idx_theta[..., :n_elem] @@ -656,75 +892,6 @@ def build_rope_cache( return torch.cos(idx_theta), torch.sin(idx_theta) -def batched_index_select(t, dim, idx): - """index_select for batched index and unbatched t""" - if idx.dim() == 1: - return torch.index_select(t, dim, idx) - - *batch_shape, idx_size = idx.shape - res = torch.index_select(t, dim, idx.reshape(-1)) # flat index - # split out single batch idx - res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) - if dim > 0: - # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors - dims = [dim] + list(range(res.dim())) - del dims[dim + 1] - res = res.permute(dims) - # unflatten batch dims - res = res.view(*batch_shape, *res.shape[1:]) - return res - - -def batched_index_copy_(t, dim, idx, val): - """Index copy for batched t, idx, val""" - - if t.device.type == "mps": - # Normalize negative dimensions - if dim < 0: - dim = t.dim() + dim - if idx.dim() == 1: - idx_shape = [1] * val.dim() - idx_shape[dim] = -1 - idx_expanded = idx.view(*idx_shape) - idx_expanded = idx_expanded.expand_as(val) - t.scatter_(dim, idx_expanded, val) - return t - - elif idx.dim() == 2: - assert dim != 0, "Cannot index the batch dimension" - batch_size = idx.size(0) - idx_size = idx.size(1) - assert batch_size == t.size(0) == val.size(0) - - idx_shape = [batch_size] + [1] * (val.dim() - 1) - idx_shape[dim] = idx_size - idx_expanded = idx.view(*idx_shape) - idx_expanded = idx_expanded.expand_as(val) - - t.scatter_(dim, idx_expanded, val) - return t - else: - raise NotImplementedError(f"idx.dim() == {idx.dim()} not supported") - - else: - if idx.dim() == 1: - return t.index_copy_(dim, idx, val) - - assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}" - assert dim != 0, f"cannot index batch dim {dim=}" - batch_size, idx_size = idx.shape - assert batch_size == t.size(0) - assert batch_size == val.size(0) - - # if we can view the batch and indexed dimensions together, we could - # do index trickery. This is, sadly, not the case for kvcache so we - # fall back to for loop - for i in range(batch_size): - unbatched_dim = dim if dim < 0 else dim - 1 - t[i].index_copy_(unbatched_dim, idx[i], val[i]) - return t - - def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """ Applies RoPE transform to `x`. 
Note that `cos`, `sin` need to have a batch @@ -738,7 +905,7 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T Returns: Encoded tensor, `(B, ..., T, head_size)` """ - if cos.dim() != 3: + if cos.ndim != 3: raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") if cos.shape != sin.shape: raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") @@ -746,7 +913,7 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T x1 = x[..., :head_size_half] # (B, ..., T, head_size/2) x2 = x[..., head_size_half:] # (B, ..., T, head_size/2) rotated = torch.cat((-x2, x1), dim=-1) # (B, ..., T, head_size) - dims_diff = x.dim() - cos.dim() + dims_diff = x.ndim - cos.ndim if dims_diff > 0: # Ensure that shapes of `x`, `cos`, `sin` align new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] @@ -757,64 +924,6 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) -def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: - return torch.tanh(x / thresh) * thresh - - -class KVCache(nn.Module): - """ - Buffers `k`, `v` have shape - `(batch_size, n_query_groups, max_seq_length, head_size)`. - """ - - def __init__( - self, - k_shape: Tuple[int, int, int, int], - v_shape: Tuple[int, int, int, int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__() - self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) - self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) - - def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Writes new values `k` and `v` into the cache at the positions specified - by `input_pos` along the sequence dimension (`max_seq_length`). The batch - size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch - size. Returns the full buffers, adjusted to the batch size `bs`. - - Args: - input_pos: Position index, `(bs, T)` or `(T,)` - k: New values, `(bs, n_query_groups, T, head_size)` - v: New values, `(bs, n_query_groups, T, head_size)` - - Returns: - k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)` - - """ - # move the buffer to the activation dtype for when AMP is used - if self.k.dtype != k.dtype: - self.k = self.k.to(k.dtype) - if self.v.dtype != v.dtype: - self.v = self.v.to(v.dtype) - # update the cache - bs = k.size(0) - k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k) - v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v) - return k, v - - def reset_parameters(self) -> None: - torch.nn.init.zeros_(self.k) - torch.nn.init.zeros_(self.v) - - -def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: - ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) - return torch.tril(ones).unsqueeze(0).unsqueeze(0) - - class RMSNorm(torch.nn.Module): """Root Mean Square Layer Normalization. 
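[Review note] With the `mask_cache` and tensor-valued `input_pos` plumbing removed above, inference is now driven by an integer `input_pos` plus per-layer KV caches. A minimal sketch of the calling protocol, distilled from `tests/test_batch.py` and `tests/test_model.py` further down in this patch; the tiny config values are illustrative only.

import torch
from litgpt import GPT, Config

config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)
model = GPT(config).eval()
model.max_seq_length = 32
model.set_kv_cache(batch_size=1)  # creates default dense KV caches per layer

prompt = torch.randint(0, config.vocab_size, (1, 5))
with torch.inference_mode():
    # Pre-fill: input_pos=0 resets the caches and fills them with the prompt.
    logits = model(prompt, input_pos=0)
    next_tok = logits[:, -1:].argmax(dim=-1)  # (1, 1)
    # Decode: input_pos must equal the number of tokens already processed.
    logits = model(next_tok, input_pos=prompt.size(1))

Each decode call is checked against the cache's `next_token_pos`, so skipping the pre-fill or passing an out-of-order `input_pos` raises an error rather than silently corrupting the cache.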
diff --git a/litgpt/utils.py b/litgpt/utils.py index af97fa2f11..ba76d4d310 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -873,3 +873,22 @@ def kill_process_tree(pid: int): parent.kill() except psutil.NoSuchProcess: pass # Process already exited + + +def batched_index_select(t: torch.Tensor, dim: int, idx: torch.Tensor) -> torch.Tensor: + """index_select for batched index and unbatched t""" + if idx.ndim == 1: + return torch.index_select(t, dim, idx) + + *batch_shape, idx_size = idx.shape + res = torch.index_select(t, dim, idx.reshape(-1)) # flat index + # split out single batch idx + res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) + if dim > 0: + # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors + dims = [dim] + list(range(res.ndim)) + del dims[dim + 1] + res = res.permute(dims) + # unflatten batch dims + res = res.view(*batch_shape, *res.shape[1:]) + return res diff --git a/tests/generate/test_adapter.py b/tests/generate/test_adapter.py index 782d3d435c..af50efbb75 100644 --- a/tests/generate/test_adapter.py +++ b/tests/generate/test_adapter.py @@ -35,6 +35,7 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): tokenizer_mock = Mock() tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]]) tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz" + tokenizer_mock.return_value.eos_id.return_value = 255 # TODO (does not work) monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() generate_mock.return_value = torch.tensor([[3, 2, 1]]) diff --git a/tests/generate/test_main.py b/tests/generate/test_main.py index fd430318b0..6328e1930c 100644 --- a/tests/generate/test_main.py +++ b/tests/generate/test_main.py @@ -15,7 +15,7 @@ import litgpt.generate.base as generate from litgpt import GPT, Config -from litgpt.generate.base import sample +from litgpt.generate.base import batched_sample skip_in_ci_on_macos = pytest.mark.skipif( sys.platform == "darwin" and os.getenv("GITHUB_ACTIONS") == "true", @@ -23,10 +23,7 @@ ) -@pytest.mark.parametrize( - "max_seq_length", (pytest.param(10, marks=pytest.mark.xfail(raises=NotImplementedError, strict=True)), 20 + 5) -) -def test_generate(max_seq_length): +def test_generate(): import lightning as L L.seed_everything(1234) @@ -34,25 +31,41 @@ def test_generate(max_seq_length): T = 5 input_idx = torch.arange(0, T) - config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) + config = Config( + block_size=128, + vocab_size=16, + n_layer=1, + n_head=4, + n_embd=8, + ) model = GPT(config) - model.max_seq_length = max_seq_length - model.set_kv_cache(batch_size=1) max_new_tokens = 20 + model.max_seq_length = T + max_new_tokens + model.set_kv_cache(batch_size=1) multinomial_results = [] def multinomial(*args, **kwargs): - out = torch.multinomial(*args, **kwargs, num_samples=1) + if args: + probs = args[0] + else: + probs = kwargs.get("probs") + out = torch.multinomial(probs, num_samples=1) multinomial_results.append(out) return out with mock.patch("litgpt.generate.base.multinomial_num_samples_1", multinomial): - out = generate.generate(model, input_idx, T + max_new_tokens, top_k=1) + out = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=T + max_new_tokens, + top_k=1, + ) assert out.size(0) == T + max_new_tokens, (out.size(0), T + max_new_tokens) multinomial_results = torch.hstack(multinomial_results) - expected = torch.cat((input_idx, multinomial_results)) + print(f"input_idx 
{input_idx.shape}, multinomial_results: {multinomial_results.shape}") + expected = torch.cat((input_idx, multinomial_results.squeeze(0))) assert out.shape == expected.shape, (out.shape, expected.shape) torch.testing.assert_close(out, expected) @@ -60,32 +73,108 @@ def multinomial(*args, **kwargs): @skip_in_ci_on_macos def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): config_path = fake_checkpoint_dir / "model_config.yaml" - config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1} + config = { + "block_size": 128, + "vocab_size": 50, + "n_layer": 2, + "n_head": 4, + "n_embd": 8, + "rotary_percentage": 1, + } config_path.write_text(yaml.dump(config)) module_mock = Mock() - module_mock.config.block_size = 128 + module_mock.config.block_size = config["block_size"] load_mock = Mock() load_mock.return_value = load_mock monkeypatch.setattr(generate, "load_checkpoint", load_mock) tokenizer_mock = Mock() tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3]) tokenizer_mock.return_value.decode.return_value = "foo bar baz" + tokenizer_mock.return_value.eos_id.return_value = 255 # TODO (does not work) monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock) generate_mock = Mock() - generate_mock.return_value = torch.tensor([3, 2, 1]) + generate_mock.return_value = torch.tensor( + [ + 1, + 2, + 3, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 0, + ] + ) + len_return_value = generate_mock.return_value.numel() monkeypatch.setattr(generate, "generate", generate_mock) num_samples = 2 out, err = StringIO(), StringIO() + sample_kwargs = dict( + temperature=2.0, + top_k=2, + top_p=0.9, + ) with redirect_stdout(out), redirect_stderr(err): - generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir) + generate.main( + **sample_kwargs, + num_samples=num_samples, + checkpoint_dir=fake_checkpoint_dir, + ) assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples - assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value) + assert torch.allclose( + tokenizer_mock.return_value.decode.call_args[0][0].to(torch.device("cpu")), generate_mock.return_value + ) assert ( generate_mock.mock_calls - == [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)] + == [call(ANY, tensor_like, len_return_value, **sample_kwargs, eos_id=tokenizer_mock.return_value.eos_id)] * num_samples ) expected_output = "foo bar baz\n" * num_samples @@ -119,15 +208,24 @@ def test_sample(temperature): ], dtype=torch.float32, ) - token = sample(logits, temperature=temperature, top_p=0.8) + # Note: Both `sample` and `batched_sample` create only 1 sample, not 3. 
+ # It is like passing `logits[:, 1-:, :]` + token = batched_sample(logits, kwargs=dict(temperature=temperature, top_p=0.8)) - assert token.shape == (1,) + assert token.shape == (2, 1) # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported - assert token.tolist() == [0] + assert token[0, -1].item() == 0 def test_generate_different_results_with_different_top_p(): - config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8) + config = Config( + block_size=128, + vocab_size=16, + n_layer=1, + n_head=4, + n_embd=8, + rotary_percentage=1, + ) model = GPT(config) model.max_seq_length = 50 model.set_kv_cache(batch_size=1) @@ -136,8 +234,18 @@ def test_generate_different_results_with_different_top_p(): input_idx = torch.randint(10, size=(1,)) torch.manual_seed(123) - output1 = generate.generate(model, input_idx, 20, top_p=1.0) + output1 = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=20, + top_p=1.0, + ) torch.manual_seed(123) - output2 = generate.generate(model, input_idx, 20, top_p=0.1) + output2 = generate.generate( + model=model, + prompt=input_idx, + max_returned_tokens=20, + top_p=0.1, + ) assert not torch.equal(output1, output2) diff --git a/tests/generate/test_sequentially.py b/tests/generate/test_sequentially.py index 37175fa489..10dae3ff76 100644 --- a/tests/generate/test_sequentially.py +++ b/tests/generate/test_sequentially.py @@ -111,6 +111,7 @@ def _test_model_1device(accelerator): fabric = Fabric(accelerator=accelerator, devices=1) with torch.device("meta"): model = GPT.from_name("pythia-14m", n_layer=2) + model.set_kv_cache(1) model = sequential(model, fabric.device, 15, 1) device_str = str(fabric.device) diff --git a/tests/kvcache/test_base.py b/tests/kvcache/test_base.py new file mode 100644 index 0000000000..f75d031ded --- /dev/null +++ b/tests/kvcache/test_base.py @@ -0,0 +1,67 @@ +import random + +import torch + +from litgpt.kvcache.base import KVCacheParams +from litgpt.kvcache.test_utils import ( + create_kv_cache, + random_keys_values, + random_tensor, + tensor_is_simple, +) + + +def test_most_recent(): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + + params = KVCacheParams( + batch_size=3, + n_query_groups=4, + cache_length=32, + head_size=8, + n_head=4, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache("lastrec-default", params) + num_insert = random.randint(cache_length, 3 * cache_length) + max_prefill_length = kv_cache.max_prefill_length + num_prefill = random.randint(num_insert // 3, int(num_insert * 0.75)) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + + keys, values = random_keys_values(params, num=num_insert) + queries = random_tensor(params, num=num_insert) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, num_insert), + ) + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, num_insert): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + + current_length = min(cache_length, num_insert) + assert kv_cache.current_length == current_length + token_positions 
= kv_cache.token_positions().to(dtype=torch.int64) + assert token_positions.shape == (params.batch_size, params.n_query_groups, current_length) + assert tensor_is_simple(token_positions) + positions = token_positions[0, 0, :].tolist() + assert len(set(positions)) == current_length + assert all(num_insert - current_length <= x < num_insert for x in positions) diff --git a/tests/kvcache/test_generic.py b/tests/kvcache/test_generic.py new file mode 100644 index 0000000000..affa507c97 --- /dev/null +++ b/tests/kvcache/test_generic.py @@ -0,0 +1,152 @@ +import random + +import pytest +import torch + +from litgpt.kvcache.base import KVCacheParams +from litgpt.kvcache.test_utils import ( + KV_CACHE_NAMES, + create_kv_cache, + random_keys_values, + random_tensor, + tensor_is_simple, +) + + +@pytest.mark.parametrize("name", KV_CACHE_NAMES) +def test_store_retrieve(name): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + + params = KVCacheParams( + batch_size=3, + n_query_groups=4, + cache_length=32, + head_size=8, + n_head=4, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache(name, params) + if name.startswith("dense"): + num_insert = random.randint(cache_length // 2, cache_length) + else: + num_insert = random.randint(cache_length, 3 * cache_length) + max_prefill_length = kv_cache.max_prefill_length + num_prefill = random.randint(num_insert // 3, int(num_insert * 0.75)) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + + keys, values = random_keys_values(params, num=num_insert) + queries = random_tensor(params, num=num_insert) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, num_insert), + ) + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, num_insert): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + + current_length = min(cache_length, num_insert) + assert kv_cache.current_length == current_length + token_positions = kv_cache.token_positions().to(dtype=torch.int64) + assert token_positions.shape == (params.batch_size, params.n_query_groups, current_length) + assert tensor_is_simple(token_positions) + # Positions for every (b, h) must be different + for b, h in zip(range(params.batch_size), range(params.n_query_groups)): + token_pos = token_positions[b, h, :].tolist() + assert all(0 <= x < num_insert for x in token_pos) + err_msg = f"num_insert = {num_insert}, b = {b}, h = {h}, current_length = {current_length}, num_prefill = {num_prefill}" + assert len(set(token_pos)) == current_length, err_msg + # Test cache content slice by slice + keys_and_values = kv_cache.get_keys_values() + for pos in range(current_length): + index = token_positions[:, :, pos][:, :, None, None].expand(-1, -1, 1, params.head_size) + # `index[i, j, 0, k] = next_position[i, j]` + k_expected = keys.gather(-2, index).squeeze(-2) + v_expected = values.gather(-2, index).squeeze(-2) + torch.testing.assert_close(k_expected, keys_and_values.keys()[:, :, pos, :]) + torch.testing.assert_close(v_expected, keys_and_values.values()[:, :, pos, :]) + + +@pytest.mark.parametrize("name", KV_CACHE_NAMES) +def 
test_prefill(name): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + vocab_size = 128 + num_compares = 3 + + params = KVCacheParams( + batch_size=2, + n_query_groups=2, + cache_length=32, + head_size=64, + n_head=2, + device=torch.device("cpu"), + dtype=torch.bfloat16, + ) + cache_length = params.cache_length + kv_cache = create_kv_cache(name, params) + + keys, values = random_keys_values(params, num=cache_length) + queries = random_tensor(params, num=cache_length) + token_idx = torch.randint( + low=0, + high=vocab_size, + size=(params.batch_size, cache_length), + ) + keys_cached = [] + values_cached = [] + max_prefill_length = kv_cache.max_prefill_length + for _ in range(num_compares): + num_prefill = random.randint(cache_length // 8, cache_length) + if max_prefill_length is not None and num_prefill > max_prefill_length: + num_prefill = max_prefill_length + kv_cache( + query=queries[:, :, :num_prefill, :], + key=keys[:, :, :num_prefill, :], + value=values[:, :, :num_prefill, :], + token_idx=token_idx[:, :num_prefill], + input_pos=0, + ) + for pos in range(num_prefill, cache_length): + kv_cache( + query=queries[:, :, pos : (pos + 1), :], + key=keys[:, :, pos : (pos + 1), :], + value=values[:, :, pos : (pos + 1), :], + token_idx=token_idx[:, pos : (pos + 1)], + input_pos=pos, + ) + keys_and_values = kv_cache.get_keys_values() + if keys_and_values is not None: + keys_cached.append(keys_and_values.keys().clone()) + values_cached.append(keys_and_values.values().clone()) + else: + keys_cached.append(None) + values_cached.append(None) + + num_none = 0 + for k, v in zip(keys_cached[1:], values_cached[1:]): + if k is not None: + torch.testing.assert_close(k, keys_cached[0]) + torch.testing.assert_close(v, values_cached[0]) + else: + num_none += 1 + assert num_none < num_compares - 1 diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 33f628eda2..e8ba05e739 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -132,8 +132,7 @@ def test_adapter_compile(): model = GPT(model.config) model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + explanation = torch._dynamo.explain(model)(x, input_pos=0) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 1e9837fc53..465b4e277e 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -161,8 +161,7 @@ def test_adapter_v2_compile(): model = AdapterV2GPT(model.config) model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + explanation = torch._dynamo.explain(model)(x, input_pos=0) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 diff --git a/tests/test_batch.py b/tests/test_batch.py index 32eb1c2f3a..860b75f5a8 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -31,7 +31,11 @@ def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]: init="random", ) model: GPT = llm.model - model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length, device=device) + model.set_kv_cache( + batch_size=batch_size, + max_seq_length=max_seq_length, + device=device, + ) return llm, model @@ -41,8 +45,9 @@ def test_batched_equivalence(tmp_path): model_name = "microsoft/phi-2" 
download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 + max_seq_length = 50 sample_kwargs = {"top_k": 1} llm: LLM = LLM.load( @@ -65,8 +70,18 @@ def test_batched_equivalence(tmp_path): batch_x1 = torch.stack([x] * batch_size, dim=0) # Single token generation baseline - tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), **sample_kwargs) - tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs) + tok_1 = next_token( + model=model, + x=x.unsqueeze(0), + input_pos=0, + **sample_kwargs, + ) + tok_2 = next_token( + model=model, + x=tok_1.unsqueeze(0), + input_pos=x.shape[0], + **sample_kwargs, + ) assert tok_1.ndim == 1 assert tok_2.ndim == 1 @@ -75,10 +90,24 @@ def test_batched_equivalence(tmp_path): # Switch to batched generation model.clear_kv_cache() - model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device="cuda:0") + model.set_kv_cache( + batch_size=batch_size, + max_seq_length=max_seq_length, + device=device, + ) - toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs) - toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs) + toks_1: torch.Tensor = batched_next_token( + model=model, + x=batch_x1, + input_pos=0, + kwargs=sample_kwargs, + ) + toks_2: torch.Tensor = batched_next_token( + model=model, + x=toks_1, + input_pos=x.shape[0], + kwargs=sample_kwargs, + ) assert toks_1.ndim == 2 assert toks_2.ndim == 2 @@ -97,27 +126,21 @@ def test_simple_batch(): config = litgpt.Config.from_name("microsoft/phi-2", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256) with torch.device("cuda"): m = litgpt.GPT(config).requires_grad_(False).eval() - x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]]) - input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]]) - x1 = torch.tensor([[1], [2]]) - input_pos1 = torch.tensor([[4], [3]]) - - with torch.device("cuda"): + m.max_seq_length = 10 + # Note: This KV cache can be used throughout, also for batch size 1 + # It is reset whenever `input_pos=0` (prefill) m.set_kv_cache(2) - outs0 = m(x0, input_pos0) - outs1 = m(x1, input_pos1) - - with torch.device("cuda"): - m.set_kv_cache(1) + x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + x1 = torch.tensor([[1], [2]]) - outs0_ref0 = m(x0[:1], input_pos0[0]) - outs1_ref0 = m(x1[:1], input_pos1[0]) + outs0 = m(x0, input_pos=0) + outs1 = m(x1, input_pos=4) - with torch.device("cuda"): - m.set_kv_cache(1) + outs0_ref0 = m(x0[:1], input_pos=0) + outs1_ref0 = m(x1[:1], input_pos=4) - outs0_ref1 = m(x0[1:], input_pos0[1]) - outs1_ref1 = m(x1[1:], input_pos1[1]) + outs0_ref1 = m(x0[1:], input_pos=0) + outs1_ref1 = m(x1[1:], input_pos=4) outs0_ref = torch.cat([outs0_ref0, outs0_ref1]) outs1_ref = torch.cat([outs1_ref0, outs1_ref1]) @@ -133,7 +156,7 @@ def test_simple_batch(): def test_batch_generate(tmp_path): torch.use_deterministic_algorithms(True) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) @@ -151,7 +174,7 @@ def test_batch_generate(tmp_path): # Generate tokens tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, @@ -216,10 +239,10 @@ def find_unique_stop(triplets): # Now we generate again, stopping early at the stop tokens. 
tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, - stop_tokens=[(s,) for s in stops], + stop_tokens=tuple([s] for s in stops), sample_args=sample_kwargs, include_prompt=True, include_eos=False, @@ -257,7 +280,7 @@ def find_unique_stop(triplets): def test_batch_generate_equivalence(tmp_path): torch.use_deterministic_algorithms(True) - device = "cuda:0" + device = torch.device("cuda:0") batch_size = 3 sample_kwargs = {"top_k": 1} llm, model = create_llm(tmp_path, batch_size, 50, device) @@ -276,7 +299,7 @@ def test_batch_generate_equivalence(tmp_path): batch_tokens = [] for l in batched_generate_fn( - model, + model=model, prompts=batch_x, max_returned_tokens=50, sample_args=sample_kwargs, @@ -292,7 +315,7 @@ def test_batch_generate_equivalence(tmp_path): tokens = [] for t in generate_fn( - model, + model=model, prompt=batch_x[0], max_returned_tokens=50, include_prompt=False, diff --git a/tests/test_chat.py b/tests/test_chat.py index 3bfe49780d..fa9280e53c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -46,6 +46,9 @@ def test_generate(monkeypatch, generated, stop_tokens, expected): model = MagicMock() model.config.block_size = 100 model.max_seq_length = 100 + # Mock methods called during generation + monkeypatch.setattr(model, "kv_cache_max_prefill_length", lambda: 80) + monkeypatch.setattr(model, "kv_cache_max_tokens_forward", lambda: 20) it = iter(generated) def multinomial(*_, **__): diff --git a/tests/test_model.py b/tests/test_model.py index 9f55f5d373..18a2e7ac71 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,10 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import math +import random from copy import deepcopy from functools import partial -from unittest import mock +from typing import Optional import pytest import torch @@ -32,9 +34,16 @@ from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM +import litgpt.attention import litgpt.config as config_module from litgpt import GPT, Config -from litgpt.model import CausalSelfAttention, batched_index_copy_ +from litgpt.attention import ( + build_mask_cache, + build_mask_slice, + DefaultKeysAndValues, + scaled_dot_product_attention, +) +from litgpt.model import CausalSelfAttention from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, @@ -46,7 +55,7 @@ copy_weights_qwen_3, ) from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved -from litgpt.utils import _RunIf +from litgpt.utils import _RunIf, batched_index_select @torch.inference_mode() @@ -706,7 +715,13 @@ def test_against_original_gemma(model_name, device, dtype): torch.set_default_dtype(dtype) T = 5 - ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + ours_config = Config.from_name( + model_name, + n_layer=2, + n_head=16, + n_embd=32, + intermediate_size=86, + ) theirs_config = GemmaConfig( vocab_size=ours_config.padded_vocab_size, hidden_size=ours_config.n_embd, @@ -771,6 +786,7 @@ def test_against_original_gemma_2(model_name, device, dtype): n_head=16, n_embd=32, intermediate_size=86, + rotary_percentage=1.0, # Gemma2 does not have this ) theirs_config = Gemma2Config( vocab_size=ours_config.padded_vocab_size, @@ -805,7 +821,6 @@ def test_against_original_gemma_2(model_name, device, dtype): # test end to end x = torch.randint(low=0, 
high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) - assert x.size(1) == T ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5) @@ -1270,56 +1285,20 @@ def test_model_compile(): model = GPT(model.config) model.set_kv_cache(2) - input_pos = torch.arange(model.config.block_size) - explanation = torch._dynamo.explain(model)(x, input_pos) + explanation = torch._dynamo.explain(model)(x) assert isinstance(explanation, debugging.ExplainOutput) assert explanation.graph_count == 1 assert explanation.graph_break_count == 0 -@torch.inference_mode() -@pytest.mark.parametrize( - "max_seq_length", (25, pytest.param(23, marks=pytest.mark.xfail(raises=IndexError, strict=True))) -) -@pytest.mark.flaky(reruns=5) -def test_kv_cache(max_seq_length): - config = Config(block_size=25, padded_vocab_size=5, n_layer=2, n_head=2, n_embd=8) - model = GPT(config) - idx = torch.randint(0, model.config.padded_vocab_size, (1, 5)) - max_new_tokens = 20 - model.max_seq_length = max_seq_length - model.set_kv_cache(1) - - def generate(logits): - logits = logits[:, -1:] - probs = torch.nn.functional.softmax(logits, dim=-1) - return torch.argmax(probs).unsqueeze(0).unsqueeze(0) - - x_no_cache = idx - x_cache = idx - input_pos = torch.arange(0, 5) - for _ in range(max_new_tokens): - logits_no_cache = model(x_no_cache[:, -max_seq_length:]) - out_no_cache = generate(logits_no_cache) - - logits_cache = model(x_cache, input_pos) - out_cache = generate(logits_cache) - - torch.testing.assert_close(out_no_cache, out_cache, rtol=0, atol=0) - - x_no_cache = torch.cat((x_no_cache, out_no_cache), dim=1) - x_cache = out_cache - input_pos = input_pos[-1:] + 1 - - @torch.inference_mode() def test_model_kv_cache_amp(): config = Config.from_name("pythia-14m", n_layer=2) model = GPT(config) - encoded = torch.arange(45) + encoded = torch.arange(45).view(1, -1) model.set_kv_cache(batch_size=1) with torch.autocast("cpu", torch.bfloat16): - output = model(encoded.unsqueeze(0), encoded) + output = model(encoded, input_pos=0) assert output.dtype is torch.bfloat16 @@ -1369,7 +1348,9 @@ def assert_sdpa_backend(original_fn, q, k, v, mask): pytest.xfail() for h in model.transformer.h: - h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention) + litgpt.attention.scaled_dot_product_attention = partial( + assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention + ) if SUPPORTS_FLASH_ATTENTION: expected = SDPBackend.FLASH_ATTENTION @@ -1414,25 +1395,26 @@ def assert_sdpa_backend(original_fn, q, k, v, mask): model.max_seq_length = 1 model.set_kv_cache(2) x = torch.randint(0, 10, (2, 1), dtype=torch.int32) - input_pos = torch.tensor([0], dtype=torch.long) except torch.cuda.OutOfMemoryError: # best effort, if the GPU can load it pytest.xfail() for h in model.transformer.h: - h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention) + litgpt.attention.scaled_dot_product_attention = partial( + assert_sdpa_backend, litgpt.attention.scaled_dot_product_attention + ) if SUPPORTS_FLASH_ATTENTION: # flash attention does not support an attention mask expected = SDPBackend.MATH with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False): - model(x, input_pos) + model(x, input_pos=0) expected = ( SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 and config.n_query_groups != 1 else SDPBackend.MATH ) with 
torch.backends.cuda.sdp_kernel(enable_flash=False): - model(x, input_pos) + model(x, input_pos=0) @_RunIf(min_cuda_gpus=2, standalone=True) @@ -1463,68 +1445,6 @@ def test_reset_parameters_device(): assert model.cos.device.type == "cuda" -def test_batched_index_copy_modes(): - # Mock the torch.backends.mps.is_available() function to simulate MPS availability - with mock.patch("torch.backends.mps.is_available", return_value=True): - # Mock the device type to simulate the "mps" device - with mock.patch("torch.Tensor.device", new_callable=mock.PropertyMock) as mock_device: - mock_device.return_value = torch.device("mps") - - # Test case when idx.dim() == 1 - t_original_1 = torch.randn(3, 5) - dim_1 = 0 - idx_1 = torch.tensor([0, 2]) - val_1 = torch.randn(2, 5) - - t1_cpu = t_original_1.clone() - t1_mps = t_original_1.clone() - - # Perform the index copy on CPU - batched_index_copy_(t1_cpu, dim_1, idx_1, val_1) - - # Simulate the MPS index copy - idx_1_mps = idx_1 - val_1_mps = val_1 - batched_index_copy_(t1_mps, dim_1, idx_1_mps, val_1_mps) - assert torch.allclose(t1_cpu, t1_mps), "Mismatch with idx.dim() == 1 on mocked MPS" - - # Test case when idx.dim() == 2 - t_original_2 = torch.randn(2, 5, 4) - dim_2 = 1 - idx_2 = torch.tensor([[0, 2], [1, 3]]) - val_2 = torch.randn(2, 2, 4) - - t2_cpu = t_original_2.clone() - t2_mps = t_original_2.clone() - - # Perform the index copy on CPU - batched_index_copy_(t2_cpu, dim_2, idx_2, val_2) - - # Simulate the MPS index copy - idx_2_mps = idx_2 - val_2_mps = val_2 - batched_index_copy_(t2_mps, dim_2, idx_2_mps, val_2_mps) - assert torch.allclose(t2_cpu, t2_mps), "Mismatch with idx.dim() == 2 on mocked MPS" - - # Additional test with negative dimension - t_original_3 = torch.randn(2, 3, 4) - dim_3 = -2 - idx_3 = torch.tensor([[0, 1], [1, 2]]) - val_3 = torch.randn(2, 2, 4) - - t3_cpu = t_original_3.clone() - t3_mps = t_original_3.clone() - - # Perform the index copy on CPU - batched_index_copy_(t3_cpu, dim_3, idx_3, val_3) - - # Simulate the MPS index copy - idx_3_mps = idx_3 - val_3_mps = val_3 - batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) - assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" - - def test_load_legacy_state_dict(): """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" config = Config( @@ -1589,20 +1509,119 @@ def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim) assert model.sin.shape == required_shape -def test_forward_with_without_input_pos_maxp1(): - batch_size = 3 - config = Config( - block_size=25, - padded_vocab_size=5, - n_layer=2, - n_head=8, - n_embd=16, - ) - model = GPT(config) - model.set_kv_cache(batch_size) - idx = torch.randint(0, config.padded_vocab_size, (1, 10)) - input_pos = torch.arange(1, 11) - input_pos_maxp1 = 11 - logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1) - logits_no_maxp1 = model(idx, input_pos) - torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1) +@pytest.mark.parametrize( + ("n_head", "n_query_groups"), + ( + (2, 1), + (4, 1), + (8, 4), + (12, 4), + (24, 8), + (9, 3), + ), +) +@torch.inference_mode() +def test_scaled_dot_product_attention(n_head, n_query_groups): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + num_repeats = 5 + dtype = torch.bfloat16 + + for repeat in range(num_repeats): + head_size = 2 ** random.randint(3, 6) + batch_size = random.randint(1, 5) + len_key = 
random.randint(16, 128) + is_causal = repeat % 2 == 0 + if is_causal: + len_query = len_key + elif repeat % 4 == 1: + len_query = random.randint(1, len_key // 2) + else: + len_query = 1 + shape = (batch_size, n_head, len_query, head_size) + query = torch.randn(shape, dtype=dtype) + shape = (batch_size, n_query_groups, len_key, head_size) + key = torch.randn(shape, dtype=dtype) + value = torch.randn(shape, dtype=dtype) + k_and_v = DefaultKeysAndValues(key, value) + scale = 1.0 / math.sqrt(head_size) + + result, scores = scaled_dot_product_attention( + query, + k_and_v, + scale=scale, + is_causal=is_causal, + ) + q_per_kv = n_head // n_query_groups + key_bc = key.repeat_interleave(q_per_kv, dim=1) + value_bc = value.repeat_interleave(q_per_kv, dim=1) + k_and_v_bc = DefaultKeysAndValues(key_bc, value_bc) + result_cmp, scores_cmp = scaled_dot_product_attention( + query, + k_and_v_bc, + scale=scale, + is_causal=is_causal, + ) + msg = ( + f"bs={batch_size}, hs={head_size}, nh_q={n_head}, nh_k={n_query_groups}, len_q={len_query}, len_k={len_key}" + ) + kwargs = dict(atol=0.0005, rtol=0.05) + torch.testing.assert_close(result, result_cmp, **kwargs), msg + torch.testing.assert_close(scores, scores_cmp, **kwargs), msg + + +@pytest.mark.parametrize( + ("sliding_window_size", "batch_size", "n_query_groups"), + ( + (None, 1, 1), + (None, 4, 16), + (4, 1, 1), + (4, 2, 32), + (128, 1, 1), + (128, 4, 16), + ), +) +@torch.inference_mode() +def test_build_mask_slice( + sliding_window_size: Optional[int], + batch_size: int, + n_query_groups: int, +): + seed = 31415927 + random.seed(seed) + torch.random.manual_seed(seed) + num_repeats = 10 + dtype = torch.bfloat16 + device = torch.device("cpu") + + for _ in range(num_repeats): + seq_len = random.randint(16, 256) + full_mask = build_mask_cache(seq_len, sliding_window_size, device, dtype) + input_pos = random.randint(1, seq_len - 1) + num = random.randint(1, min(16, seq_len - input_pos)) + cache_length = random.randint(8, seq_len - 4) + token_positions = torch.zeros( + (batch_size, n_query_groups, cache_length), + dtype=torch.int64, + device=device, + ) + for bs in range(batch_size): + for nq in range(n_query_groups): + token_positions[bs, nq, :] = torch.randperm( + seq_len, device=device, + )[:cache_length] + mask = build_mask_slice( + input_pos=input_pos, + num=num, + token_positions=token_positions, + dtype=dtype, + device=device, + sliding_window_size=sliding_window_size, + ) + mask_cmp = batched_index_select( + full_mask[input_pos: (input_pos + num), :], + dim=1, + idx=token_positions, + ) + torch.testing.assert_close(mask, mask_cmp) From e3fa10e5b09fec2770e643bdbe642d77417ed12d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 May 2025 07:08:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/attention.py | 29 ++++++++++++++++++++--------- litgpt/config.py | 1 + tests/test_model.py | 7 ++++--- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/litgpt/attention.py b/litgpt/attention.py index 6315cedca2..7cc15be0c5 100644 --- a/litgpt/attention.py +++ b/litgpt/attention.py @@ -323,7 +323,10 @@ def build_mask_cache( """ # Usual causal mask: mask = torch.ones( - max_seq_length, max_seq_length, device=device, dtype=dtype, + max_seq_length, + max_seq_length, + device=device, + dtype=dtype, ).triu(diagonal=1) if sliding_window_size is not None: mask += 
torch.ones_like(mask).tril(diagonal=-sliding_window_size) @@ -363,15 +366,23 @@ def build_mask_slice( tp_dtype = token_positions.dtype token_positions = token_positions.unsqueeze(2).to(device=device) kwargs = dict(device=device, dtype=tp_dtype) - bool_mask = torch.arange( - input_pos, input_pos + num, **kwargs, - ).view(1, 1, -1, 1) < token_positions - if sliding_window_size is not None: - extra_mask = torch.arange( - input_pos - sliding_window_size, - input_pos + num - sliding_window_size, + bool_mask = ( + torch.arange( + input_pos, + input_pos + num, **kwargs, - ).view(1, 1, -1, 1) >= token_positions + ).view(1, 1, -1, 1) + < token_positions + ) + if sliding_window_size is not None: + extra_mask = ( + torch.arange( + input_pos - sliding_window_size, + input_pos + num - sliding_window_size, + **kwargs, + ).view(1, 1, -1, 1) + >= token_positions + ) bool_mask += extra_mask mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device) mask.masked_fill_(bool_mask, torch.finfo(dtype).min) diff --git a/litgpt/config.py b/litgpt/config.py index a5cb2ebc5f..6217d4b905 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -22,6 +22,7 @@ def find_multiple(n: int, k: int) -> int: return n return n + k - (n % k) + # See `Config.start_of_layer_hook`. A start of layer hook is called just before # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where # `x` is the layer input, `block_idx` the number of the layer, and `input_pos` diff --git a/tests/test_model.py b/tests/test_model.py index 18a2e7ac71..3605c4971d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -38,9 +38,9 @@ import litgpt.config as config_module from litgpt import GPT, Config from litgpt.attention import ( + DefaultKeysAndValues, build_mask_cache, build_mask_slice, - DefaultKeysAndValues, scaled_dot_product_attention, ) from litgpt.model import CausalSelfAttention @@ -1609,7 +1609,8 @@ def test_build_mask_slice( for bs in range(batch_size): for nq in range(n_query_groups): token_positions[bs, nq, :] = torch.randperm( - seq_len, device=device, + seq_len, + device=device, )[:cache_length] mask = build_mask_slice( input_pos=input_pos, @@ -1620,7 +1621,7 @@ def test_build_mask_slice( sliding_window_size=sliding_window_size, ) mask_cmp = batched_index_select( - full_mask[input_pos: (input_pos + num), :], + full_mask[input_pos : (input_pos + num), :], dim=1, idx=token_positions, )
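[Review note] The equivalence exercised by `test_build_mask_slice` above can be read as: building the mask slice directly from `input_pos`, `num`, and the cached `token_positions` must match slicing a dense causal mask at the query rows and gathering its key columns. A condensed, runnable restatement of that test, assuming the new modules land at the import paths used in this patch.

import torch

from litgpt.attention import build_mask_cache, build_mask_slice
from litgpt.utils import batched_index_select

seq_len, cache_length, input_pos, num = 64, 16, 10, 4
dtype, device = torch.bfloat16, torch.device("cpu")

# One permutation of token positions per (batch, query-group) slot, the way the
# KV caches report them via `token_positions()`.
token_positions = torch.stack(
    [torch.randperm(seq_len)[:cache_length] for _ in range(2)]
).unsqueeze(0)  # (batch_size=1, n_query_groups=2, cache_length)

# Dense reference: slice query rows [input_pos, input_pos + num) out of the full
# causal mask, then gather the key columns at the cached token positions.
full_mask = build_mask_cache(seq_len, None, device, dtype)
sliced = batched_index_select(
    full_mask[input_pos : (input_pos + num), :], dim=1, idx=token_positions
)

# Direct construction from positions only, no dense (T, T) mask required.
direct = build_mask_slice(
    input_pos=input_pos,
    num=num,
    token_positions=token_positions,
    dtype=dtype,
    device=device,
    sliding_window_size=None,
)
torch.testing.assert_close(direct, sliced)

Avoiding the dense `(T, T)` mask is what lets eviction-style caches (where kept token positions are arbitrary per batch element and query group) build their attention bias on the fly during decoding.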