diff --git a/.gitignore b/.gitignore index e843852..dc3d4e4 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,9 @@ wandb # downloaded by our tests original_model.py -original_adapter.py \ No newline at end of file +original_adapter.py + +.vscode + +torch_compile_debug +results \ No newline at end of file diff --git a/GPTQ.py b/GPTQ.py index e1279bd..546cb7e 100644 --- a/GPTQ.py +++ b/GPTQ.py @@ -15,7 +15,7 @@ from eval import ( setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, - GPTFastEvalWrapper + GPTFastEvalWrapper, ) @@ -63,7 +63,6 @@ def __init__( ) self.pad_calibration_inputs = False - def add_input(self, args): if self.inputs is None: self.inputs = [MultiInput([arg]) for arg in args] @@ -113,7 +112,6 @@ def _model_call(self, inps): ) - class MultiInput: def __init__(self, inputs): self.values = list(inputs) @@ -126,7 +124,9 @@ def __getitem__(self, slice): return MultiInput(self.values[slice]) def cuda(self): - self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values] + self.values = [ + val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values + ] class GenericGPTQRunner(fx.Interpreter): @@ -235,7 +235,12 @@ def tensors_to_cuda(args): ) transposed_args = list( zip( - *[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args] + *[ + x.values + if isinstance(x, MultiInput) + else [x] * multi_input_count + for x in flat_args + ] ) ) else: @@ -244,8 +249,8 @@ def tensors_to_cuda(args): # check whether we apply GPTQ to this module quantize_linear = ( - (target == aten.linear.default) # if its a linear - and id(args[1]) in self.id_to_name # and if we know the layer name + (target == aten.linear.default) # if its a linear + and id(args[1]) in self.id_to_name # and if we know the layer name and not skip_quant # and if we weren't told to skip quantization # and if the skip_layer_func doesn't say we should skip and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) @@ -259,9 +264,7 @@ def tensors_to_cuda(args): inp = tensors_to_cuda(inp) cur_args, cur_kwargs = tree_unflatten(inp, spec) - if ( - quantize_linear - ): # calculate H instead of output (will run the linear eventually with updated weight) + if quantize_linear: # calculate H instead of output (will run the linear eventually with updated weight) x = cur_args[0].float() shape = x.shape n = 1 if len(shape) == 2 else shape[0] @@ -333,11 +336,14 @@ def SQNR(x, y): target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True ) - print("SQNR for output without GPTQ (should be less than above)", - torch.cat([ + print( + "SQNR for output without GPTQ (should be less than above)", + torch.cat( + [ SQNR(old.cpu(), old_q.cpu()).unsqueeze(0) for (old, old_q) in zip(old_out.values, old_q_out.values) - ]).mean(), + ] + ).mean(), ) return new_out diff --git a/README.md b/README.md index 5cef03c..be36666 100644 --- a/README.md +++ b/README.md @@ -1,207 +1,29 @@ -# gpt-fast -Simple and efficient pytorch-native transformer text generation. +# Fast-Compress -Featuring: -1. Very low latency -2. <1000 lines of python -3. No dependencies other than PyTorch and sentencepiece -4. int8/int4 quantization -5. Speculative decoding -6. Tensor parallelism -7. Supports Nvidia and AMD GPUs +**This a WIP - do not use unless you are interested in contributing to the ongoing project.** -This is *NOT* intended to be a "framework" or "library" - it is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire. +This repo extends [GPT-Fast](https://github.com/pytorch-labs/gpt-fast) by adding SOTA KV Cache compression methods. -For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/). - -## Examples -In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs. -- [Gemma support](https://github.com/pytorch-labs/gpt-fast/pull/115) -## Supported Models - -### LLaMA family -Please check the rest of this page about benchmark of LLaMA family models. - -### Mixtral 8x7B -We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are: - -| | 1 GPU | 2 GPU | 4 GPU | 8 GPU | -|------------------|---------|-----------|--------|------------| -|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 | -| int8 | 97.92 | 155.03 | 216.87 | 279.35 | - -Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens). - -For more details about Mixtral 8x7B, please check [this page](./mixtral-moe) or this [note](https://thonking.substack.com/p/short-supporting-mixtral-in-gpt-fast). - -## Community - -Projects inspired by gpt-fast in the community: - -- [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2). -- [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Huggingface models -- [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera): extends `gpt-fast` to SFT/RM/PPO training and batched inference to optimize the throughput +When done, it *will* serve as an open-source, hackable toolkit to accelerate research onto memory efficient inference. ## Installation [Download PyTorch nightly](https://pytorch.org/get-started/locally/) -Install sentencepiece and huggingface_hub ```bash -pip install sentencepiece huggingface_hub +pip install packaging ninja +MAX_JOBS=8 pip install flash-attn --no-build-isolation # Set MAX_JOBS to a lower value if you get OOM errors. +pip install -r requirements.txt ``` -To download llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through steps to obtain access. -Then login with `huggingface-cli login` - +After logging in with `huggingface-cli login`, run - -## Downloading Weights -Models tested/supported -```text -tinyllamas/stories{15,42,100} -openlm-research/open_llama_7b -meta-llama/Llama-2-7b-chat-hf -meta-llama/Llama-2-13b-chat-hf -meta-llama/Llama-2-70b-chat-hf -codellama/CodeLlama-7b-Python-hf -codellama/CodeLlama-34b-Python-hf -mistralai/Mistral-7B-v0.1 -mistralai/Mistral-7B-Instruct-v0.1 -mistralai/Mistral-7B-Instruct-v0.2 -``` - -For example, to convert Llama-2-7b-chat-hf ```bash -export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -./scripts/prepare.sh $MODEL_REPO +bash scripts/prepare_llama3.sh ``` -## Benchmarks -Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens). - -| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | -| -------- | ------- | ------ | ------ | -| Llama-2-7B | Base | 104.9 | 1397.31 | -| | 8-bit | 155.58 | 1069.20 | -| | 4-bit (G=32) | 196.80 | 862.69 | -| Llama-2-70B | Base | OOM || -| | 8-bit | 19.13 | 1322.58 | -| | 4-bit (G=32) | 25.25 | 1097.66 | - -### Speculative Sampling -[Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s +This will create necessary model and tokenizer files for`Meta-Llama-3-8B-Instruct` within `./checkpoints`. It will also create a smaller model for debugging purposes only, called `Meta-Llama-3-8B-Instruct-4-Layers`. This model removes all layers except for the first 4. It's quicker to load but will generate nonsense, so only use for debugging. -### Tensor Parallelism -| Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) | -| -------- | ------- | ------ | ------ | -| Llama-2-7B | 1 | 104.9 | 1397.31 | -| | 2 | 168.84 | 1181.99 | -| | 4 | 254.02 | 955.83 | -| | 8 | 328.43 | 704.10 | -| Llama-2-70B | 1 | OOM | | -| | 2 | 21.32 | 1481.87 | -| | 4 | 38.01 | 1340.76 | -| | 8 | 62.50 | 1135.29 | +## Usage -### Tensor Parallelism + Quantization -| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | -| -------- | ------- | ------ | ------ | -| Llama-2-70B | Base | 62.50 | 1135.29 | -| | 8-bit | 80.44 | 752.04 | -| | 4-bit (G=32) | 90.77 | 548.10 | - -### AMD -Benchmarks run on one GCD of a MI-250x. - -| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | -| -------- | ------- | ------ | ------ | -| Llama-2-7B | Base | 76.33 | 1028.70 | -| | 8-bit | 101.86 | 700.06 | - -## Generate Text - -Model definition in `model.py`, generation code in `generate.py`. - -```bash -python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is" -``` - -To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though. - -## Quantization -Choose device to use by -```bash -# The current support devices: cuda, cpu -export DEVICE=cuda -``` -### Int8 Weight-Only Quantization -To generate this version of the model -```bash -# Spits out model at checkpoints/$MODEL_REPO/model_int8.pth -python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8 -``` -To run with int8, just pass the int8 checkpoint to generate.py. -```bash -python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE ``` - -### Int4 Weight-Only Quantization -To generate int4 version of model -```bash -# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth -python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 -``` - -To run with int4, just pass the int4 checkpoint to generate.py. -```bash -python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile -``` - -## Speculative Sampling -To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model compared with MODEL_REPO). - -In this example, the "smaller" model is just the int8 quantized version of the model. -``` -export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth -``` - -Note: Running on an A100 80GB, albeit power-limited to 330 watts. Empirically, seems like peak bandwidth is about 1700 GB/s. - - -## Tensor Parallelism -```bash -ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth -``` - -## Experimental -### Evaluation -We use the EleutherAI evaluation harness to evaluate our model accuracy. To evaluate the accuracy, make sure the evaluation harness is installed and pass your model checkpoint and desired tasks to eval.py. - -```bash -python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande -``` - -Note: Generative tasks are currently not supported for gpt-fast - -Installation Instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install - -### GPTQ -We have a pure pytorch implementation of GPTQ that utilizes torch._dynamo.export to access the model structure. You can generate a GPTQ quantized -version of int4 quantization by using the same command to quantize it but adding 'gptq' to the quantization mode i.e. -```bash -# Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth -python quantize.py --mode int4-gptq --calibration_tasks wikitext --calibration_seq_length 2048 -``` - -You can then eval or generate text with this model in the same way as above. - -## License - -`gpt-fast` is released under the [BSD 3](https://github.com/pytorch-labs/gpt-fast/main/LICENSE) license. - -## Acknowledgements -Thanks to: -* Lightning AI for supporting pytorch and work in flash attention, int8 quantization, and LoRA fine-tuning. -* GGML for driving forward fast, on device inference of LLMs -* Karpathy for spearheading simple, interpretable and fast LLM implementations -* MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware +python generate.py --compile --cache_strategy full --prompt "short_prompt_long_output.txt" +``` \ No newline at end of file diff --git a/attention_utils.py b/attention_utils.py new file mode 100644 index 0000000..4c670a7 --- /dev/null +++ b/attention_utils.py @@ -0,0 +1,45 @@ +import math +from typing import Tuple + +import torch +from torch.nn import functional as F + + +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + scale=None, + return_attn=False, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor | None]: + """ + Uses naive PyTorch sdpa implementation if we need to return_attn. Otherwise use the optimized version. + + The naive implementation will be optimized later. + """ + if not return_attn: + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + scale=scale, + ), None + B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_weight = query @ key.transpose(-2, -1) * scale_factor + + if attn_mask is not None: + attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device) + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + attn_weight += attn_bias + + # TODO if returning attn_weight, should we just modify the attn_weight tensor to be attn_prob? + attn_prob = torch.softmax(attn_weight, dim=-1) + attn_prob = torch.dropout(attn_prob, dropout_p, train=True) + return_logits = kwargs.get("return_attn_logits", False) + return attn_prob @ value, attn_weight if return_logits else attn_prob diff --git a/cache.py b/cache.py new file mode 100644 index 0000000..6f51e0a --- /dev/null +++ b/cache.py @@ -0,0 +1,1344 @@ +import regex as re +from abc import ABC, abstractmethod +from typing import Tuple, Callable + +import math +import torch +import torch.nn as nn +from prompt_compression import prompt_compressor_constructor +import argparse + + +def add_cache_arguments(parser: argparse.ArgumentParser): + group = parser.add_argument_group("cache_args") + # KV-Cache Kwargs + group.add_argument( + "--max_cache_length", + type=float, + default=[1.0], + nargs="+", + help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \ + If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.", + ) + strategies = ["full", "random", "window", "scissor", "l2", "fastgen", "gist"] + debug_strategies = [f"debug_{strategy}" for strategy in strategies] + strategies.extend(debug_strategies) + + group.add_argument( + "--cache_strategy", + default="full", + choices=strategies, + ) + + # Dealing with Long Prompts + parser.add_argument( + "--feed_long_prompts", + default=False, + action="store_true", + help="If True and |prompt| > max_cache_length, prefill with prompt[:max_cache_length], and feed prompt[max_cache_length:] sequentially.", + ) + group.add_argument( + "--prompt_compression_strategy", # This doesn't matter if args.feed_long_prompts is True + default="recent_global", + choices=["recent_global", "snapkv", "l2", "random"], + help="If |prompt| exceeds max_cache_length, we need to specify a strategy for compressing it to max_cache_length.", + ) + + # Optional Cache Kwargs depending on cache_strategy + group.add_argument( + "--global_tokens", + default=1, + type=int, + help="The number of initial tokens to always include in the KV-Cache. \ + If using window strategy, the actual window becomes max_cache_length - global_tokens.", + ) + + # Locality + group.add_argument( + "--recent_window", # NB: for KVCacheWindow, recent_window is implicitly set to self.max_cache_length - self.global_tokens. + default=10, # 10 is default specified in ScissorHands paper ("r" in Algorithm 2). + type=float, # If < 1, it is a fraction of max_cache_length. + help="The number of recently generated tokens to always spare from eviction.", + ) + + # Scissorhands-specific Hyperparameters (--cache_strategy == "scissor") + ## See Algorithm 1 & 2 in arxiv.org/abs/2305.17118 + group.add_argument( + "--history_window_size", # Equivalent to "m" in Algorithm 2. + default=400, # 400 is default specified in paper. + type=int, + help="The number of past tokens to consider when computing 'Heavy Hitters' in the KV-Cache.", + ) + group.add_argument( + "--drop_amount", # Equivalent to "m" in Algorithm 2. + default=0.0, # 0 means we re-calculate eviction token every time. 0.4 is default specified in paper. + type=float, + help="The number of tokens to evict KV-Cache reaches capacity (max_cache_length). Expressed as a fraction of max_cache_length.", + ) + group.add_argument( + "--attn_thresholding", + default=False, + action="store_true", + help="Whether to accumulate number of times a token was unimportant (binary) versus raw un-normalized probabilities. If true, more memory efficient.", + ) + group.add_argument( + "--attn_record_freq", + default=1, + type=int, + help="How often to record attention weights for the ScissorHands cache.", + ) + + # FastGen-specific Hyperparameters (--cache_strategy == "fastgen") + parser.add_argument( + "--heavy_hitter_frac", + default=0.3, + type=float, + help="Fraction of max_cache_length to consider as heavy hitters in the KV-cache.", + ) + + parser.add_argument( + "--min_recovery_frac", + default=0.9, + type=float, + help="Mininum fraction of recovered attentions (|compressed_attn - uncompressed_attn| < epsilon). The lower the value, the higher the compression.", + ) + + +def cache_compatibility(args): + if args.cache_strategy in ("full", "gist"): + # Full implies no compression, which means --max_cache_length = [1.0] (same size as prompt + max_new_tokens) + assert all( + [l == 1.0 for l in args.max_cache_length] + ), "Full cache strategy only supports max_cache_length=1.0." + + if args.cache_strategy == "gist": + assert "gist" in str(args.checkpoint_path), "You must provide a gist token id for the gist cache." + + # Attention-based eviction policies must use an attention-based prompt compressor + if args.cache_strategy in {"scissor"}: + assert ( + args.prompt_compression_strategy == "snapkv" + ), 'Scissor requires "snapkv" prompt compression strategy' + + print("The cache argument values you provided appear compatible with each other!") + + +def create_window_attention_mask(seq_len, window_size, device): + # Initialize the mask tensor with zeros + mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device) + for i in range(seq_len): + mask[i, max(0, i - window_size) : i] = True + return mask + + +class KVCache(ABC, nn.Module): + # Define which hyperparameters are relevant for the cache. + # Override as needed for sub-classes. + relevant_kwargs = ["max_cache_length", "max_seq_length", "global_tokens"] + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + dtype=torch.bfloat16, + head_specific=False, + variable_length=False, + **kwargs, + ): + super().__init__() + + # Assign each kwarg as an attribute of the class + for key, value in kwargs.items(): + setattr(self, key, value) + + cache_shape = (max_batch_size, n_heads, self.max_cache_length, head_dim) + k_cache = torch.zeros(cache_shape, dtype=dtype) + v_cache = torch.zeros(cache_shape, dtype=dtype) + self.register_buffer("k_cache", k_cache) + self.register_buffer("v_cache", v_cache) + + # Can we evict different tokens for different heads? + # If the answer is yes, we need to store self.pos for each head. + self.head_specific = head_specific + self.register_buffer( + "pos", # Track pos to keep track of the original positions of each item in cache. + torch.full( + ( + max_batch_size, + n_heads if head_specific else 1, + self.max_cache_length, + ), + -1, + dtype=torch.int, + ), + ) + self.register_buffer( + "cache_cts", + torch.zeros((n_heads if variable_length else 1), dtype=torch.int), + ) + + # Incase the |prompt| > max_cache_length, we need to compress the prompt which requires a separate strategy + self.prompt_compressor = ( + None + if self.prompt_compression_strategy is None + else prompt_compressor_constructor(self.prompt_compression_strategy)( + head_specific=self.head_specific, **kwargs + ) + ) + + # This turns True when the global tokens are fully filled + self.global_filled = self.global_tokens == 0 + self.always_keep_prompt = self.global_tokens == -1 + + # KVCacheFastGen requires profiling attention heads during prefill. This must be handled with separate callback. + self.prefill_attn_callback = None + + def reset(self): + """ + Resets the cache to its initial state for a new example. + + NB: For more performance, don't reset k_cache and v_cache since we overwrite them in update. + """ + self.k_cache.zero_() + self.v_cache.zero_() + self.cache_cts.zero_() + self.pos.fill_(-1) + if self.always_keep_prompt: + self.global_tokens = ( + -1 + ) # -1 means we will resize it to the prompt size during prefill + self.global_filled = self.global_tokens == 0 + + def return_attn(self): + """ + Returns whether the cache requires attention weights for cache management. + """ + return False + + def compute_statistics(self, seq_len): + """ + Computes statistics about the cache. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The cache size, the number of tokens inserted, and the compression ratio. + """ + return { + "compression_ratio": self.compression_ratio(seq_len).item(), + } + + def compression_ratio(self, seq_len): + """ + Returns the compression ratio of the cache. + """ + # Final token isn't passed to cache so must -1 from seq_len + n = seq_len - 1 + return ((n - torch.clamp_max(self.cache_cts, self.max_cache_length)) / n).mean() + + def return_kv_cache(self): + # Truncate the cache based on number of insertions. It will be at the end since we prefill in-order. + k = ( + self.k_cache[:, :, : self.cache_cts, :] + if self.cache_cts < self.max_cache_length + else self.k_cache + ) + v = ( + self.v_cache[:, :, : self.cache_cts, :] + if self.cache_cts < self.max_cache_length + else self.v_cache + ) + + # Since we truncate there's no mask + mask = None + return k, v, mask + + def is_prefill(self): + """ + Returns whether the cache is in the prefill stage. + """ + # self.cache_cts is either tensor scalar or tensor of shape [num_heads] for variable-length caches. + return self.cache_cts.max() == 0 + + def compress_prompt_w_attn(self, input_pos, k_val, v_val, attn) -> None: + # If the prompt is longer than the cache, we need to compress it to fit cache and then store (update). + assert ( + input_pos.shape[-1] > self.max_cache_length + ), "You called compress_prompt in prefill stage yet prompt is not longer than max_cache_length." + input_pos, k_val, v_val, attn = self.prompt_compressor( + input_pos, k_val, v_val, attn + ) + # If you need input_ids you will have to pass them to prompt_compressor and have prompt_compressor return them in proper order. + # Only FastGen uses them for now and it has its own prefill callback called "profile_and_update". + self.update(input_pos, k_val, v_val, input_ids=None) + self.update_attn_history(attn) + + def compress_prompt( + self, input_pos, k_val, v_val + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Callable | None]: + mask = None # We will be performing causal attention on full inputs (mask won't be used) + if self.prompt_compressor.requires_attn(): + compress_callback = { + "func": ( + lambda input_pos, + input_ids, + k_val, + v_val, + attn: self.compress_prompt_w_attn(input_pos, k_val, v_val, attn) + ) + } + return k_val, v_val, mask, compress_callback + + # If we can compress without attention, we don't need to pass it as a callback. We can call update now and see if there's a different callback. + _, _, _, new_callback = self.update( + *self.prompt_compressor(input_pos, k_val, v_val) + ) + # Yet we return the un-compressed KV since during pre-fill we compute full causal attention. + return k_val, v_val, mask, new_callback + + def attn_history_callback(self) -> Callable | None: + """ + Returns a callback to update the attention history. + + Returns None if attention is not needed + """ + return ( + { + "func": lambda input_pos, + input_ids, + k_val, + v_val, + attn: self.update_attn_history(attn) + } + if self.return_attn() + else None + ) + + def update(self, input_pos, k_val, v_val, input_ids=None): + """ + Updates the cache with the given input positions, keys, and values. + + Parameters: + input_pos (torch.Tensor): A tensor of input positions. + k_val (torch.Tensor): A tensor of keys. + v_val (torch.Tensor): A tensor of values. + input_ids (torch.Tensor): A tensor of input ids. + + Returns: + Tuple[torch.Tensor, torch.Tensor, bool]: A tuple containing the updated cache of keys and values, + both truncated to the minimum of the current insertions and the maximum cache length. The last value + is a boolean return_attn indicating whether the cache requires attention weights. If True, the model + will call self.update_attn_history with the attention weights. + """ + is_prefill = self.is_prefill() + num_tokens = input_pos.shape[-1] + + # FastGen requires a special callback for prefill that profiles attention heads and updates the cache. + prefill_callback = None if not self.is_prefill() else self.prefill_attn_callback + if prefill_callback is not None: + mask = None + return k_val, v_val, mask, prefill_callback + + # If the prompt is longer than the cache, we need to compress it to fit cache and then store (update). + prompt_overflow = num_tokens > self.max_cache_length + if prompt_overflow: + return self.compress_prompt(input_pos, k_val, v_val) + + # k_val: [B, H, S, D] -> S is > 1 for prefill, 1 for new tokens + if is_prefill: + assert num_tokens > 1 + else: + assert num_tokens == 1 + + attn_history_callback = ( + { + "func": lambda input_pos, + input_ids, + k_val, + v_val, + attn: self.update_attn_history(attn) + } + if self.return_attn() + else None + ) + + self.cache_cts += self._update(input_pos, k_val, v_val, input_ids=input_ids) + + k, v, mask = self.return_kv_cache() + return k, v, mask, attn_history_callback + + @abstractmethod + def _update(self, input_pos, k_val, v_val, input_ids=None): + """ + Cache-specific update logic. + Takes in the input positions and the corresponding k and v values. + Modifies self.pos, self.k_cache, self.v_cache place. + + Returns a tensor indicating the number of tokens inserted - number of tokens evicted. + None is equivalent to 0. + """ + pass + + def fill_contiguous(self, input_pos, k_val, v_val, start=None, end=None): + """ + A simple utility to fill the cache and pos. + If start and end are provided, only fill the cache between those indices. + Otherwise, treat start as self.cache_cts and end as self.cache_cts + num_new_insertions. + Will also mark the global_tokens as they are updated. + """ + num_insertions = self.cache_cts[ + 0 + ] # If we are calling this function, self.cache_cts should be uniform across all heads + num_new_insertions = k_val.shape[2] + if start is None: + assert end is None + start = num_insertions + end = start + num_new_insertions + + self.pos[:, :, start:end] = input_pos.int() + + self.k_cache[:, :, start:end, :] = k_val + self.v_cache[:, :, start:end, :] = v_val + + if hasattr( + self, "global_tokens" + ): # If we have global tokens we need to mark them in self.pos + # Update global tokens to the prompt size if set to -1 + resize_global_tokens = self.global_tokens == -1 + if resize_global_tokens: + self.global_tokens = num_new_insertions + self.global_filled = self.global_filled or self.mark_global_tokens( + num_insertions + num_new_insertions + ) + + def fill_headwise(self, fill_indices, input_pos, k_val, v_val): + """ + Modifies the cache in-place with key-value pairs at given fill_indices. + + Args: + fill_indices (torch.Tensor): The indices specifying the positions to fill in the cache. + input_pos (torch.Tensor): The input positions corresponding to the fill_indices. + k_val (torch.Tensor): The key values to fill in the fill_indices slots. + v_val (torch.Tensor): The value values to fill in the fill_indices slots. + + Returns: + None + """ + # fill_indices [num_heads] or [1] + # input_pos [seq_len] or [num_heads, seq_len] + # k_val, v_val [batch_size, n_heads, seq_len, head_dim] + assert input_pos.shape[-1] == k_val.shape[2] == v_val.shape[2] + + # input_pos is either [seq_len] or [num_heads, seq_len] + pos_fill_indices = fill_indices.view(1, -1, 1) + cache_fill_indices = fill_indices.view(1, len(fill_indices), 1, 1).expand( + 1, k_val.shape[1], 1, k_val.shape[-1] + ) + input_pos = input_pos.view(1, -1, 1).expand(1, k_val.shape[1], 1).int() + self.pos.scatter_(2, pos_fill_indices, input_pos.int()) + self.k_cache.scatter_(2, cache_fill_indices, k_val) + self.v_cache.scatter_(2, cache_fill_indices, v_val) + + def update_attn_history(self, attn): + """ + Update the attention history with the most recent attention weights. + """ + raise Exception( + f"{self.__class__.__name__} requested return_attn=True but has not yet implemented a update_attn_history function." + ) + + def mark_global_tokens(self, num_total_insertions: int) -> bool: + """ + Update POS tensor to give global tokens highest priority. + + num_total_insertions: The total number of tokens inserted so far. The sum of cache_cts and num_new_insertions. + + Return a boolean indicating whether or not all global tokens were filled. + + If it returns True, this function won't be called again to save computation. + """ + assert hasattr( + self, "global_tokens" + ), "This cache does not have global tokens so we cannot mark them." + # Give self.pos an highest possible position value for global tokens so that they are not replaced + num_to_mark = min(self.global_tokens, num_total_insertions) + self.pos[:, :, :num_to_mark] = self.max_seq_length + return num_to_mark == self.global_tokens + +class KVCacheGist(KVCache): + relevant_kwargs = [ + 'gist_token_id', + 'max_cache_length' + ] + def __init__( + self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs + ): + # Gist does not use additional compression strategies + self.prompt_compression_strategy = None + self.global_tokens = 0 # No global tokens for gist cache + + self.gist_token_id = kwargs.pop('gist_token_id') + super().__init__( + max_batch_size, n_heads, head_dim, dtype, head_specific=False, **kwargs + ) + self.prefill_attn_callback = { + "func": self.profile_and_update, + "kwargs": {}, + } + self.register_buffer( + "ids", # Track ids to keep track of the original ids of each item in cache. required to determine gist mask in case of multi-batch inputs + torch.full( + ( + max_batch_size, + self.max_cache_length, + ), + -1, + dtype=torch.int, + ), + ) + + + def _update(self, input_pos, k_val, v_val, input_ids=None): + # input_pos: [S], k_val: [B, H, S, D], input_ids: [B, S] + + self.fill_contiguous(input_pos, k_val, v_val) + self.ids[:, self.cache_cts[0]:self.cache_cts[0]+input_ids.shape[-1]] = input_ids + return input_pos.shape[-1] + + def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn): + assert self.is_prefill(), "Should only be profiling during prefill stage." + + gist_pos = torch.where(input_ids == self.gist_token_id)[-1].min().cpu().item() # use lowest position of gist token in case of multi batch inputs + seq_len = input_pos.shape[-1] + input_pos = input_pos[gist_pos:] + input_ids = input_ids[:, gist_pos:] + k_val = k_val[:, :, gist_pos:, :] + v_val = v_val[:, :, gist_pos:, :] + + self.fill_contiguous(input_pos, k_val, v_val) + self.ids[:, self.cache_cts[0]:self.cache_cts[0]+input_ids.shape[-1]] = input_ids + self.cache_cts[0] = input_pos.shape[-1] + + def return_kv_cache(self): + k, v, mask = super().return_kv_cache() + mask_shape = (k.shape[0], k.shape[1], 1, k.shape[-2]) + gist_mask = torch.ones(mask_shape, dtype=torch.bool).to(k.device) + gist_token_positions = torch.stack(torch.where(self.ids == self.gist_token_id)).T + for position in gist_token_positions: + gist_mask[position[0], :, :, :position[1]] = False + return k, v, gist_mask + +class KVCacheFull(KVCache): + def __init__( + self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs + ): + # Never any prompt compression for full cache + self.prompt_compression_strategy = None + self.global_tokens = 0 # No global tokens for full cache (they are all global) + super().__init__( + max_batch_size, n_heads, head_dim, dtype, head_specific=False, **kwargs + ) + + def _update(self, input_pos, k_val, v_val, input_ids=None): + # input_pos: [S], k_val: [B, H, S, D] + self.fill_contiguous(input_pos, k_val, v_val) + return input_pos.shape[-1] + + +class KVCacheRandom(KVCache): + relevant_kwargs = [ + "max_cache_length", + "max_seq_length", + "global_tokens", + "prompt_compression_strategy", + ] + + def __init__( + self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs + ): + super().__init__( + max_batch_size, n_heads, head_dim, dtype, head_specific=False, **kwargs + ) + + def _update(self, input_pos, k_val, v_val, input_ids=None): + start = end = None # We will fill the cache in order if start and end are None + + need_to_evict = self.cache_cts >= self.max_cache_length + if need_to_evict: # Select a spot at random + start = torch.randint(low=0, high=self.max_cache_length, size=(1,)) + end = start + 1 + + # Specify specific start and end indices + self.fill_contiguous(input_pos, k_val, v_val, start=start, end=end) + return input_pos.shape[-1] + + +class KVCacheWindow(KVCache): + relevant_kwargs = [ + "max_cache_length", + "max_seq_length", + "global_tokens", + "prompt_compression_strategy", + # NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens. + ] + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + dtype=torch.bfloat16, + head_specific=False, + variable_length=False, + **kwargs, + ): + super().__init__( + max_batch_size, + n_heads, + head_dim, + dtype, + head_specific, + variable_length, + **kwargs, + ) + + def _update(self, input_pos, k_val, v_val, input_ids=None): + start = end = None # We will fill the cache in order if start and end are None + + need_to_evict = self.cache_cts >= self.max_cache_length + if need_to_evict: # Identify the least recent spot + start = torch.argmin(self.pos) + assert ( + input_pos.shape[-1] == 1 + ), "Should only be passing 1 new token at a time after cache is filled!" + end = start + 1 + + # Specify specific start and end indices + self.fill_contiguous(input_pos, k_val, v_val, start=start, end=end) + + return input_pos.shape[-1] + + +class KVCacheL2(KVCacheWindow): + relevant_kwargs = [ + "max_cache_length", + "max_seq_length", + "global_tokens", + "recent_window", + "prompt_compression_strategy", + ] + + def __init__( + self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs + ): + super().__init__( + max_batch_size, n_heads, head_dim, dtype, head_specific=True, **kwargs + ) + + key_norm_shape = (max_batch_size, n_heads, self.max_cache_length) + self.register_buffer("key_norm", torch.zeros(key_norm_shape, dtype=dtype)) + + def _update(self, input_pos, k_val, v_val, input_ids=None): + key_norm = torch.linalg.vector_norm(k_val, ord=2, dim=-1) + + need_to_evict = self.cache_cts >= self.max_cache_length + + if need_to_evict: + # Set global and recent tokens to have lowest possible eviction score (-inf) + eviction_score = self.key_norm.masked_fill( + self.pos >= input_pos - self.recent_window, float("-inf") + ) + fill_indices = torch.argmax(eviction_score, dim=-1).squeeze(0) + self.fill_headwise(fill_indices, input_pos, k_val, v_val) + + # Do a scatter update to update the key norms + fill_indices = fill_indices.view(1, -1, 1) + self.key_norm.scatter_(2, fill_indices, key_norm) + else: # Insert into first unfilled spots + self.fill_contiguous(input_pos, k_val, v_val) + start, end = self.cache_cts, self.cache_cts + k_val.shape[2] + self.key_norm[:, :, start:end] = key_norm + + return input_pos.shape[-1] + + def update_attn_history(self, attn): + """ + This will be called if |prompt| > max_cache_length and SnapKV prompt compression is used. + Because L2 cache does not require attention weights, this function is a no-op. + """ + pass + + +class KVCacheScissorhands(KVCacheWindow): + relevant_kwargs = [ + "max_cache_length", + "max_seq_length", + "global_tokens", + "history_window_size", + "drop_amount", + "recent_window", + "attn_thresholding", + "prompt_compression_strategy", + "attn_record_freq", + ] + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + dtype=torch.bfloat16, + head_specific=True, + requires_eviction_queue=True, + variable_length=False, + **kwargs, + ): + super().__init__( + max_batch_size, + n_heads, + head_dim, + dtype, + head_specific, + variable_length, + **kwargs, + ) + + # Initialize a buffer for the attention histories + history_num_shape = ( + max_batch_size, + n_heads, + self.max_cache_length, + self.history_window_size, + ) + history_denom_shape = ( + max_batch_size, + n_heads, + self.max_cache_length, + ) + self.register_buffer( + "attn_history_num", + torch.zeros( + history_num_shape, dtype=torch.bool if self.attn_thresholding else dtype + ), + ) + + # Ideally, we could use the self.pos to track the number of times a token has been attended to + # But any change to cache management or how self.pos is stored would break this. + self.register_buffer( + "attn_history_denom", torch.zeros(history_denom_shape, dtype=torch.int32) + ) + + self.register_buffer("attn_counter", torch.zeros((), dtype=torch.int64)) + + assert self.recent_window >= self.attn_record_freq, ( + f"Since recent window ({self.recent_window}) < attention record frequency ({self.attn_record_freq}), you will get nan scores when " + "deciding which tokens to evict because >0 non-local tokens will have no attention history." + ) + + if requires_eviction_queue: + # Different eviction queue for each attention head + eviction_queue_shape = ( + max_batch_size, + n_heads, + self.drop_amount, + ) + self.register_buffer( + "eviction_queue", torch.zeros(eviction_queue_shape, dtype=torch.int32) + ) + # Start with an "empty queue" so that we can fill it up. + self.register_buffer("eviction_idx", torch.tensor(self.drop_amount)) + + assert self.queue_len() == 0 + + def reset(self): + super().reset() + self.attn_history_num.zero_() + self.attn_history_denom.zero_() + self.attn_counter.zero_() + if hasattr(self, "eviction_queue"): + self.eviction_queue.zero_() + # Start with an "empty queue" so that we can fill it up + self.eviction_idx.fill_(self.drop_amount) + assert self.queue_len() == 0 + + def queue_len(self): + return self.drop_amount - self.eviction_idx + + def return_attn(self) -> bool: + """ + Whether or not we need to return attention weights for cache management. + + We return attention weights if 3 conditions are met: + 1) The cache is not in the prefill stage. + 2) The number of tokens left in the eviction queue // the frequency with which we record attention < attention history window. + 3) The number of insertions is a multiple of the frequency with which we record attention. + + The number of tokens in eviction queue specifies how many turns before we need to re-calculate importance. + We only need to start recording once the number of steps until recomputation is equal to the recent window. + """ + + return ( + not self.is_prefill() + and self.queue_len() // self.attn_record_freq <= self.history_window_size + and ( + self.cache_cts.squeeze() % self.attn_record_freq == 0 + or self.attn_counter == 0 + ) + ) + + def update_attn_history(self, attn: torch.Tensor): + """ + Insert the most recent attention into the history buffer. + + If self.attn_thresholding = True, insert a binary indicator of whether the attention >= uniform attention. + """ + attn = attn.squeeze() + keys = attn.shape[1] + self.attn_history_num[ + :, :, :keys, self.attn_counter % self.history_window_size + ] = (attn >= 1 / keys).int() if self.attn_thresholding else attn + self.attn_history_denom[:, :, :keys] += 1 + self.attn_counter += 1 + + def refill_eviction_queue(self, input_pos: int): + # Identify the tokens with consistently "low" attentions + numerator = self.attn_history_num.sum(dim=-1).float() + # The denominator is the number of times this token's history has been recorded + # We only record most self.history_window_size recent scores so need to clamp it + denominator = self.attn_history_denom.clamp_max(self.history_window_size) + + avg_attn = numerator / denominator + + # Save the global & most recent tokens from being evicted + avg_attn.masked_fill_(self.pos >= input_pos - self.recent_window, 1.0) + + _, toks_to_evict = avg_attn.topk( + self.drop_amount, dim=-1, sorted=True, largest=False + ) + + # The eviction queue will be empty so just re-assign it + self.eviction_queue = toks_to_evict + self.eviction_idx.zero_() + + def _update(self, input_pos, k_val, v_val, input_ids=None): + num_new_tokens = input_pos.shape[-1] + need_to_evict = self.cache_cts >= self.max_cache_length + if not need_to_evict: # Insert into first unfilled spots + self.fill_contiguous(input_pos, k_val, v_val) + return num_new_tokens + + assert ( + self.global_filled + ), "Global tokens should be all marked as filled when cache is filled." + + # Refill the eviction queue only if it is empty (potentially expensive operation) + self.queue_len() > 0 or self.refill_eviction_queue(input_pos.item()) + + # Evict the next token in the queue (self.eviction_idx) and increment it + fill_indices = self.eviction_queue[0, :, self.eviction_idx] + self.eviction_idx += 1 + + self.fill_headwise(fill_indices, input_pos, k_val, v_val) + num_fill = fill_indices.view(1, -1, 1, 1).expand( + 1, -1, 1, self.attn_history_num.shape[-1] + ) + denom_fill = fill_indices.view(1, -1, 1) + self.attn_history_num.scatter_( + 2, num_fill, torch.zeros_like(num_fill, dtype=self.attn_history_num.dtype) + ) + self.attn_history_denom.scatter_( + 2, denom_fill, torch.zeros_like(denom_fill, dtype=torch.int32) + ) + + return num_new_tokens + + +class KVCacheFastGen(KVCacheScissorhands): + relevant_kwargs = [ + "max_cache_length", + "max_seq_length", + "history_window_size", + "recent_window", + "attn_thresholding", + "token_ids", + "prompt_compression_strategy", + "min_recovery_frac", + "heavy_hitter_frac", + ] + + strategies = [ + "special", + "special_punc", + "special_punc_heavy", + "special_punc_heavy_local", + "full", + ] + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + dtype=torch.bfloat16, + head_specific=True, + **kwargs, + ): + self.global_tokens = 0 # No global tokens for FastGen + self.attn_record_freq = 1 # We record attention every step for FastGen + super().__init__( + max_batch_size, + n_heads, + head_dim, + dtype, + head_specific, + requires_eviction_queue=False, + variable_length=True, + **kwargs, + ) + + special_ids = [torch.tensor(ids) for ids in kwargs["token_ids"]["special"]] + self.register_buffer("special_ids", torch.nested.nested_tensor(special_ids)) + + # Store the punctuation vocabulary ids + punc_ids = torch.Tensor(kwargs["token_ids"]["punctuation"]) + self.register_buffer("punc_ids", punc_ids) + # As well as a mask showing where punctuation ids are in the KV cache + # We store this to avoid re-computing the mask every time and having to store input_ids + + mask_shape = (max_batch_size, n_heads, self.max_cache_length) + self.register_buffer("special_mask", torch.zeros(mask_shape, dtype=torch.bool)) + self.register_buffer("punc_mask", torch.zeros(mask_shape, dtype=torch.bool)) + + # We need to use a mask since not all heads have same number of tokens. We can't simply truncate. + # 1 dimension stands for query dimension, which will always be 1 (next token) for KV cache attention. + kv_mask_shape = (max_batch_size, n_heads, 1, self.max_cache_length) + self.register_buffer("mask", torch.zeros(kv_mask_shape, dtype=torch.bool)) + + # NB: Kwargs are sdpa attention kwargs, not the kwargs for the "func" + self.prefill_attn_callback = { + "func": self.profile_and_update, + "kwargs": {"return_attn_logits": False}, + } + + def return_attn(self): + # We use a special callback for FastGen to profile the attention heads during prefill. + return not self.is_prefill() and self.requires_heavy_check + + def return_kv_cache(self): + return self.k_cache, self.v_cache, self.mask + + def eviction_idx_for_head(self, head_idx, input_pos, apply_window=False): + numerator = ( + self.attn_history_num[:, head_idx, : self.cache_cts[head_idx]] + .sum(dim=-1) + .float() + ) + # The denominator is the number of times this token's history has been recorded + # We only record most self.history_window_size recent scores so need to clamp it + denominator = self.attn_history_denom[ + :, head_idx, : self.cache_cts[head_idx] + ].clamp_max(self.history_window_size) + avg_attn = numerator / denominator + + # Save the special & punctuation tokens from being evicted + save_mask = torch.logical_or( + self.special_mask[:, head_idx, : self.cache_cts[head_idx]], + self.punc_mask[:, head_idx, : self.cache_cts[head_idx]], + ) + if apply_window: + save_mask = torch.logical_or( + save_mask, + self.pos[:, head_idx, : self.cache_cts[head_idx]] + >= input_pos - self.recent_window, + ) + + avg_attn.masked_fill_(save_mask, 1) + fill_idx = avg_attn.argmin(dim=-1) + + return fill_idx + + def select_fill_idx(self, strategy, head_idx, input_pos, is_punc: bool = False): + fill_idx = None + eviction_required = False + + def _end_idx(): + # We need to clone because self.cache_cts will be incremented later and we don't want to have fill_idx as a mutable reference + return min(self.max_cache_length - 1, self.cache_cts[head_idx].clone()) + + if strategy == KVCacheFastGen.strategies.index("special"): + pass # We are assuming we don't generate special tokens + elif strategy == KVCacheFastGen.strategies.index("special_punc"): + if is_punc: + fill_idx = _end_idx() + elif strategy == KVCacheFastGen.strategies.index( + "special_punc_heavy" + ) or strategy == KVCacheFastGen.strategies.index("special_punc_heavy_local"): + apply_window = strategy == KVCacheFastGen.strategies.index( + "special_punc_heavy_local" + ) + # If there's still room in the cache, just fill it in the next open slot + budget = ( + self.num_special + + self.num_punc + + (self.heavy_hitter_frac * self.max_cache_length) + ) + if apply_window: # We are also allowed budget for the recent tokens + budget += self.recent_window + + eviction_required = self.cache_cts[head_idx] >= budget + if eviction_required: + # Figure out which token to evict -- make sure we don't evict special or punc + fill_idx = self.eviction_idx_for_head( + head_idx, input_pos, apply_window=apply_window + ) + self.attn_history_num[:, head_idx, fill_idx, :].fill_(0) + self.attn_history_denom[:, head_idx, fill_idx].fill_(0) + else: + # We can fit it in the cache + fill_idx = _end_idx() + elif strategy == KVCacheFastGen.strategies.index("full"): + fill_idx = _end_idx() + else: + raise ValueError(f"Unrecognized strategy index {strategy}.") + + return fill_idx, eviction_required + + def _update(self, input_pos, k_val, v_val, input_ids=None): + n_heads = k_val.shape[1] + + is_punc = torch.isin(input_ids, self.punc_ids) + + # If fill idx is None we place value at the back (which is truncated for attention calculation anyway) + fill_indices = torch.full( + (n_heads,), + self.max_cache_length - 1, + dtype=torch.int64, + device=k_val.device, + ) + + cache_ct_incr = torch.zeros_like(fill_indices) + + for head_idx, strategy in enumerate(self.cache_strategies): + fill_idx, eviction_required = self.select_fill_idx( + strategy, head_idx, input_pos, is_punc=is_punc + ) + + if fill_idx is None: + continue + + cache_ct_incr[head_idx] = 1 + + fill_indices[head_idx] = fill_idx + if not eviction_required: + # We can't use all fill indices to bulk assign mask because some fill_indices are dummies (self.max_cache_length - 1) + self.mask[:, head_idx, :, fill_idx] = True + + # Scatter + self.fill_headwise(fill_indices, input_pos, k_val, v_val) + self.punc_mask.scatter_( + 2, fill_indices.view(1, -1, 1), is_punc.view(1, 1, 1).expand(1, n_heads, 1) + ) + + # Only update global self.num_punc once (not once per head) + # If a head keeps punc tokens, each head will have same number of punc tokens (no punc evictions) + if is_punc: + self.num_punc += 1 + + return cache_ct_incr + + def build_special_ids_mask(self, input_ids): + seq_len = input_ids.shape[-1] + special_ids_mask = torch.zeros_like(input_ids, dtype=torch.bool) + + for special_id in self.special_ids: + # Iterate over input_ids to check for the exact sub-sequence + id_len = len(special_id) + if id_len == 1: + special_ids_mask[torch.where(input_ids == special_id)[0]] = True + else: + for i in range(seq_len - id_len + 1): + if torch.equal(input_ids[i : i + id_len], special_id): + special_ids_mask[i : i + id_len] = True + return special_ids_mask + + def build_punc_ids_mask(self, input_ids): + # TODO should be on same device as model with register_buffer + if self.punc_ids.device != input_ids.device: + self.punc_ids = self.punc_ids.to(input_ids.device) + punc_ids_mask = torch.isin(input_ids, self.punc_ids) + return punc_ids_mask + + def profile_attn_heads(self, input_pos, input_ids, attn): + input_ids = input_ids.squeeze(0) + seq_len = input_ids.shape[-1] + n_heads = attn.shape[1] + + special_mask = self.build_special_ids_mask(input_ids) + special_mask_exp = special_mask.view(1, 1, -1).expand(n_heads, seq_len, seq_len) + # Store number of special tokens for later use + self.num_special = special_mask.sum() + + punc_mask = self.build_punc_ids_mask(input_ids) + self.num_punc = punc_mask.sum() + + punc_mask_exp = punc_mask.view(1, 1, -1).expand(n_heads, seq_len, seq_len) + window_mask = create_window_attention_mask( + seq_len, self.recent_window, input_ids.device + ) + + # Average of cumulative attention probs (use input_pos to normalize) + cum_attn = torch.softmax(attn, dim=-1).squeeze(0).sum(dim=1) / ( + seq_len - input_pos + ) + heavy_hitters = ( + cum_attn.topk( + # Can calculate heavy hitters based on seq_len + math.ceil(min(self.heavy_hitter_frac * seq_len, seq_len)), + dim=1, + largest=True, + ) + .indices.unsqueeze(1) + .expand(-1, seq_len, -1) + ) + + # Hybrid Strategies: + # - special + # - special + punc + # - special + punc + frequent / heavy + # - special + punc + frequent / heavy + local + # - full + special_punc = torch.logical_or(special_mask_exp, punc_mask_exp) + special_punc_heavy = special_punc.scatter(2, heavy_hitters, True) + special_punc_heavy_local = torch.logical_or(special_punc_heavy, window_mask) + + masks = torch.stack( + [ + special_mask_exp, + special_punc, + special_punc_heavy, + special_punc_heavy_local, + torch.ones_like(special_mask_exp), + ] + ) + + attn_rep = attn.expand(masks.shape[0], -1, -1, -1) + + compressed_scores = attn_rep.masked_fill(~masks, 0).sum(dim=-1).mean(dim=-1) + + # For each column, return the first row which has cost >= min_recovery_frac + cache_strategies = ( + (compressed_scores >= self.min_recovery_frac).int().argmax(dim=0) + ) + + # Take the last query's mask as the initial KV-Cache fill mask + masks_all = masks[:, :, -1, :].transpose(1, 0) + # Select mask based on self.cache_strategies + mask_optimal = masks_all.gather( + 1, cache_strategies.view(-1, 1, 1).expand(-1, -1, seq_len) + ).squeeze(1) + + return cache_strategies, special_mask, punc_mask, mask_optimal, cum_attn + + def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn): + """ + Profile the attention heads to determine the optimal KV-cache allocation. + """ + assert self.is_prefill(), "Should only be profiling during prefill stage." + + input_ids = input_ids.squeeze(0) + seq_len = input_ids.shape[-1] + n_heads = attn.shape[1] + dim = k_val.shape[-1] + + # Profile cache attention heads to define strategy for each head + self.cache_strategies, special_mask, punc_mask, mask_optimal, cum_attn = ( + self.profile_attn_heads(input_pos, input_ids, attn) + ) + + # Show which strategies are selected + print([self.strategies[i] for i in self.cache_strategies.tolist()]) + + # If none of the heads selected a heavy hitter strategy, we don't need to track attention weights + self.requires_heavy_check = any( + ["heavy" in KVCacheFastGen.strategies[i] for i in self.cache_strategies] + ) + + # Put the selected items (true values from mask) to the front. Re-arrange attentions as well. + order = mask_optimal.int().argsort(dim=1, descending=True) + order_exp = order.view(1, n_heads, seq_len, 1).expand(-1, -1, -1, dim) + + k_val = k_val.gather(2, order_exp) + v_val = v_val.gather(2, order_exp) + input_pos = input_pos.unsqueeze(0).expand(n_heads, -1).gather(1, order) + self.fill_contiguous(input_pos, k_val, v_val) + + # Record number of tokens to be inserted into the cache + self.cache_cts = mask_optimal.sum(dim=1) + + # We will need to remove special tokens and punctuation from heavy hitter eviction so need to their positions. + special_mask = special_mask.view(1, -1).expand(n_heads, -1).gather(1, order) + punc_mask = punc_mask.view(1, -1).expand(n_heads, -1).gather(1, order) + self.special_mask[:, :, :seq_len] = special_mask + self.punc_mask[:, :, :seq_len] = punc_mask + + # Update mask to reflect how many items have been inserted into each head + range_mask = ( + torch.arange(seq_len, device=self.mask.device) + .view(1, -1) + .expand(n_heads, -1) + ) + self.mask[:, :, :, :seq_len] = ( + range_mask < self.cache_cts.view(-1, 1).expand(-1, seq_len) + ).view(-1, n_heads, 1, seq_len) + + if self.requires_heavy_check: + # Update attention mask to indicate which we attentions are allowed. + cum_attn = cum_attn.gather(1, order) + self.update_attn_history(cum_attn) + + +class KVCacheAnalysis(KVCache): + relevant_kwargs = [ + "max_cache_length", + "history_window_size", + "recent_window", + "attn_thresholding", + "token_ids", + "prompt_compression_strategy", + "min_recovery_frac", + "heavy_hitter_frac", + "global_tokens", + "drop_amount", + "prompt_compression_strategy", + "attn_record_freq", + "max_seq_length", + ] + + def __init__( + self, + max_batch_size, + n_heads, + head_dim, + dtype=torch.bfloat16, + cache_strategy="scissor", + **kwargs, + ): + # Never any prompt compression for full cache + full_kwargs = { + "prompt_compression_strategy": None, + "global_tokens": 0, + "max_cache_length": kwargs["max_seq_length"], + "max_seq_length": kwargs["max_seq_length"], + } + super().__init__( + max_batch_size, n_heads, head_dim, dtype, head_specific=False, **full_kwargs + ) + + # Initialize the compressed cache we want to analyze. + self.compressed = get_cache_constructor(cache_strategy=cache_strategy)[0]( + max_batch_size, + n_heads, + head_dim, + dtype, + **kwargs, + ) + + self.register_buffer( + "attention_losses", + torch.full((self.max_seq_length,), fill_value=-1, dtype=dtype), + ) + + def return_attn(self): + return self.compressed.return_attn() + + def update(self, input_pos, k_val, v_val, input_ids=None): + k, v, mask, _ = super().update(input_pos, k_val, v_val, input_ids=input_ids) + _, _, _, attn_callback = self.compressed.update( + input_pos, k_val, v_val, input_ids=input_ids + ) + + if attn_callback is not None and input_pos.shape[-1] == 1: + # This is ugly but we need to re-write callback to call this class's update_attn_history not the compressed + # This is because we need to filter the attention weights to only the tokens in the compressed cache first. + attn_callback = self.attn_history_callback() + assert attn_callback is not None + + return k, v, mask, attn_callback + + def _update(self, input_pos, k_val, v_val, input_ids=None): + # input_pos: [S], k_val: [B, H, S, D] + self.fill_contiguous(input_pos, k_val, v_val) + return input_pos.shape[-1] + + def reset(self): + super().reset() + self.compressed.reset() + self.attention_losses.fill_(-1) + + def update_attn_history(self, attn: torch.Tensor): + indices = self.compressed.pos.clone().long() + + # Global tokens will have been set to max seq length + # We need to set them back to actual global tokens + indices[:, :, : self.compressed.global_tokens] = ( + torch.arange(self.compressed.global_tokens, device=indices.device) + .view(1, 1, -1) + .expand(1, indices.shape[1], -1) + ) + indices = indices[:, :, : min(indices.shape[-1], attn.shape[-1])] + attn_compressed = attn.squeeze(2).gather(2, indices).unsqueeze(2) + self.compressed.update_attn_history(attn_compressed) + + attn_loss = (1 - attn_compressed.sum(dim=-1)).mean() + insert_idx = torch.where(self.attention_losses == -1)[0][0] + self.attention_losses[insert_idx] = attn_loss + + def compute_statistics(self, seq_len): + """ + Computes statistics about the cache. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The cache size, the number of tokens inserted, and the compression ratio. + """ + stats = super().compute_statistics(seq_len) + cutoff = torch.where(self.attention_losses == -1)[0] + if len(cutoff) > 0: + cutoff = cutoff[0] + else: + cutoff = len(self.attention_losses) + stats["attention_loss"] = (self.attention_losses[:cutoff].sum() / cutoff).item() + return stats + + +def get_cache_constructor(cache_strategy): + relevant_kwargs = None + if cache_strategy == "full": + cls = KVCacheFull + elif cache_strategy == "l2": + cls = KVCacheL2 + elif cache_strategy == "random": + cls = KVCacheRandom + elif cache_strategy == "window": + cls = KVCacheWindow + elif cache_strategy == "scissor": + cls = KVCacheScissorhands + elif cache_strategy == "fastgen": + cls = KVCacheFastGen + elif cache_strategy == "gist": + cls = KVCacheGist + elif cache_strategy.startswith("debug"): + cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip() + relevant_kwargs = get_cache_constructor(cache_strategy)[1] + cls = ( + lambda max_batch_size, n_heads, head_dim, dtype, **kwargs: KVCacheAnalysis( + max_batch_size, + n_heads, + head_dim, + dtype, + cache_strategy=cache_strategy, + **kwargs, + ) + ) + else: + raise ValueError(f"Invalid cache strategy: {cache_strategy}") + + return cls, relevant_kwargs or cls.relevant_kwargs diff --git a/cache_configs/fastgen.yaml b/cache_configs/fastgen.yaml new file mode 100644 index 0000000..70bef6e --- /dev/null +++ b/cache_configs/fastgen.yaml @@ -0,0 +1,10 @@ +cache_strategy: "fastgen" +max_cache_length: [1.0] # [Fixed] Control compression ratio with min_recovery_frac +prompt_compression_strategy: "snapkv" # Won't be used. Fastgen profiles attn and inserts directly. +recent_window: 10 # Local window to consider for local strategies +history_window_size: 400 # How many past steps to consider for attention importance calculation +drop_amount: 0 # How frequently to calculate which tokens to evict (0 means we recalculate every step) +attn_thresholding: False # Whether to threshold attention scores or record raw probabilities +min_recovery_frac: 0.85 # Higher is less compression (0.85 means we choose the policy which compresses the most tokens AND recovers 85% of the full attention matrix) +heavy_hitter_frac: 0.3 # Higher is less compression for the heavy hitter strategy +recent_window: 0.3 \ No newline at end of file diff --git a/cache_configs/full.yaml b/cache_configs/full.yaml new file mode 100644 index 0000000..fd821e2 --- /dev/null +++ b/cache_configs/full.yaml @@ -0,0 +1 @@ +cache_strategy: "full" \ No newline at end of file diff --git a/cache_configs/l2.yaml b/cache_configs/l2.yaml new file mode 100644 index 0000000..02ec3b6 --- /dev/null +++ b/cache_configs/l2.yaml @@ -0,0 +1,4 @@ +cache_strategy: "l2" +prompt_compression_strategy: "l2" +global_tokens: 4 +recent_window: 10 \ No newline at end of file diff --git a/cache_configs/random.yaml b/cache_configs/random.yaml new file mode 100644 index 0000000..1a711db --- /dev/null +++ b/cache_configs/random.yaml @@ -0,0 +1,2 @@ +cache_strategy: "random" +global_tokens: 4 \ No newline at end of file diff --git a/cache_configs/scissor.yaml b/cache_configs/scissor.yaml new file mode 100644 index 0000000..fdd7641 --- /dev/null +++ b/cache_configs/scissor.yaml @@ -0,0 +1,8 @@ +cache_strategy: "scissor" +prompt_compression_strategy: "snapkv" +global_tokens: 4 +recent_window: 10 +history_window_size: 400 +drop_amount: 0 +attn_thresholding: False +attn_record_freq: 1 diff --git a/cache_configs/task_stats.csv b/cache_configs/task_stats.csv new file mode 100644 index 0000000..31c26c7 --- /dev/null +++ b/cache_configs/task_stats.csv @@ -0,0 +1,12 @@ +task,n,is_mcqa,prompt_tokens,label_tokens,n_choices +dolomites,664,False,780.5105421686746,468.89006024096386, +musique,2417,False,2469.275134464212,14.035579328959543, +qmsum,281,False,14065.02846975089,84.60854092526691, +rulercwe,500,False,3791.214,11.924400000000007, +rulerniah,500,False,3819.522,13.0, +rulerqa,500,False,3333.914,13.738, +rulervt,500,False,3847.114,13.107199999999976, +scrollsquality,2086,True,5986.950623202301,11.0, +squality,260,False,6879.084615384615,283.7625, +triviaqa,17210,False,10643.657989540965,13.0, +truthfulqa,817,True,152.84944920440637,11.0, diff --git a/cache_configs/window.yaml b/cache_configs/window.yaml new file mode 100644 index 0000000..f87981b --- /dev/null +++ b/cache_configs/window.yaml @@ -0,0 +1,3 @@ +cache_strategy: "window" +prompt_compression_strategy: "recent_global" +global_tokens: 4 diff --git a/eval.py b/eval.py index d38abf8..cd65361 100644 --- a/eval.py +++ b/eval.py @@ -5,266 +5,433 @@ # LICENSE file in the root directory of this source tree. import sys import time +import yaml +import argparse +import json +import regex as re +import contextlib +import shutil +import pandas as pd from pathlib import Path -from typing import Optional +from typing import Optional, List +from collections import defaultdict +from tqdm.auto import tqdm import torch import torch._dynamo.config import torch._inductor.config -torch._dynamo.config.automatic_dynamic_shapes = True + +from cache import add_cache_arguments, cache_compatibility, get_cache_constructor +from model import Transformer +from generation_utils import ( + add_generation_arguments, + compute_max_seq_length, + decode_one_token, + device_sync, + get_cache_stats, + prefill, + reset_caches, + setup_caches, +) +from tokenizer import encode, TokenizerInterface + + +torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.epilogue_fusion = False -torch._inductor.config.triton.cudagraphs = True -torch._dynamo.config.cache_size_limit = 100000 +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = "cuda" if torch.cuda.is_available() else "cpu" + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) from tokenizer import get_tokenizer +from generation_utils import load_model, generate +from task import TASK_MAPPING, AutoTask + + +def args_to_str(args): + if "debug" in args.cache_strategy: + debug_suffix = "__debug" + cache_strategy = re.sub(r"debug_+", "", args.cache_strategy).strip() + RELEVANT_CACHE_KWARGS = get_cache_constructor( + args.cache_strategy.replace("debug_", "") + )[1] + else: + cache_strategy = args.cache_strategy + debug_suffix = "" + RELEVANT_CACHE_KWARGS = get_cache_constructor(cache_strategy)[1] + + def process_num(n): + # Return integer floats as "1" not 1.0 + # Otherwise, no op + if type(n) == float and int(n) == n: + return int(n) + return n + + return ( + "__".join( + sorted( + [ + f"{k}=" + ",".join([str(process_num(m)) for m in v]) + if type(v) == list + else f"{k}={process_num(v)}" + for k, v in vars(args).items() + if k in RELEVANT_CACHE_KWARGS + ] + ) + ) + + debug_suffix + ) -from model import Transformer -try: - import lm_eval - lm_eval_available = True -except: - lm_eval_available = False - -from generate import _load_model, encode_tokens, model_forward - -if lm_eval_available: - try: # lm_eval version 0.4 - from lm_eval.models.huggingface import HFLM as eval_wrapper - from lm_eval.tasks import get_task_dict - from lm_eval.evaluator import evaluate - except: #lm_eval version 0.3 - from lm_eval import base - from lm_eval import tasks - from lm_eval import evaluator - eval_wrapper=base.BaseLM - get_task_dict=tasks.get_task_dict - evaluate=evaluator.evaluate - - -def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( +def run_task( + args: argparse.Namespace, + task: AutoTask, model: Transformer, - prompt: torch.Tensor, - max_new_tokens: int, - max_seq_length: Optional[int] = None, + tokenizer: TokenizerInterface, + is_chat: bool = False, + profile: Optional[Path] = None, + feed_long_prompts=False, + device=default_device, + cache_kwargs: dict = {}, + use_tp: bool = False, + rank: int = None, + terminator_ids: List[int] = None, ): - """ - Sets up model cache and does some bookkeeping calculations for prompt, input_pos and max_seq_length - that are needed for prefill or model_forward - - Args: - model (LLaMA): The model whose cache gets set up - prompt (torch.Tensor): Tensor of shape (T) with indices of the prompt sequence. - max_new_tokens (int): The desired maximum number of new tokens that can be generated. - max_seq_length (Optional[int], optional): The maximum sequence length allowed. - - Returns: - seq (torch.Tensor): prompt but padded with zeros to size max_seq_length - input_pos (torch.Tensor): tensor of integers in increasing order - max_seq_length (int): The maximum sequence length allowed, updated based on other numbers - """ - T = prompt.size(0) - T_new = T + max_new_tokens - if max_seq_length is None: - max_seq_length = min(T_new, model.config.block_size) - - device, dtype = prompt.device, prompt.dtype - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - return seq, input_pos, max_seq_length - -class GPTFastEvalWrapper(eval_wrapper): - """ - A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. - """ - def __init__( - self, - model: Transformer, - tokenizer, - max_seq_length: Optional[int]=None, - ): - super().__init__() - self._model = model - self._tokenizer = tokenizer - self._device = torch.device('cuda') - self._max_seq_length = 2048 if max_seq_length is None else max_seq_length - - @property - def eot_token_id(self): - return self._tokenizer.eos_id() - - @property - def max_length(self): - return self._max_seq_length - - @property - def max_gen_toks(self): - return 50 - - @property - def batch_size(self): - return 1 - - @property - def device(self): - return self._device - - def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens(self._tokenizer, - string, bos=True, device=self._device) - # encoded is a pytorch tensor, but some internal logic in the - # eval harness expects it to be a list instead - # TODO: verify this for multi-batch as well - encoded = encoded.tolist() - return encoded - - def tok_decode(self, tokens): - decoded = self._tokenizer.decode(tokens) - return decoded - - def _model_call(self, inps): - # TODO: make batches work - inps = inps.squeeze(0) - - max_new_tokens = 1 - seq, input_pos, max_seq_length = \ - setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( - self._model, - inps, - max_new_tokens, - self.max_length, + aggregate_metrics = defaultdict(list) + predictions = [] + all_probs = [] + task_metrics = {} + + test = task.get_test() + prompts = test["prompt"] + + inputs = [ + encode(tokenizer, prompt, device="cpu", is_chat=is_chat) + for prompt in tqdm(prompts, desc="Encoding Prompts") + ] + + _, max_seq_length = compute_max_seq_length(model, inputs, task.max_tokens) + + setup_caches(model, tokenizer, device, max_seq_length, cache_kwargs.copy()) + + for i in tqdm(range(len(inputs))): + input = inputs[i].to(device) + prompt_length = input.size(0) + + max_new_tokens = min(task.max_tokens, max_seq_length - prompt_length) + assert max_new_tokens > 0, f"Prompt too long for model: {prompt_length}" + + device_sync(device=device) # MKG + t0 = time.perf_counter() + + if not profile or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, probs = generate( + model, + input, + max_new_tokens=max_new_tokens, + terminator_ids=terminator_ids, + feed_long_prompts=feed_long_prompts, + ) + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["num_toks"].append(tokens_generated) + + # Reset Counters for KV Cache + cache_stats = get_cache_stats(model, prompt_length, tokens_generated) + for k, v in cache_stats.items(): + aggregate_metrics[k].append(v) + + # Decode: remove EoT and prompt + end = y.size(0) + if y[-1] in terminator_ids: + end = -1 + pred = tokenizer.decode(y[prompt_length:end].tolist()) + + if args.debug: + print(f"Prediction: {pred}") + + predictions.append(pred) + if task.requires_logits: + all_probs.append( + {k: v for k, v in zip(tokenizer.get_vocab(), probs[0].tolist())} ) - x = seq.index_select(0, input_pos).view(1, -1) - logits = model_forward(self._model, x, input_pos) - return logits - - def _model_generate(self, context, max_length, eos_token_id): - raise Exception('unimplemented') + reset_caches(model) -@torch.no_grad() -def eval( - model: Transformer, - tokenizer, - tasks: list = ["hellaswag"], - limit: Optional[int] = None, - max_seq_length: Optional[int] = None, -) -> dict: - """ - Evaluates a language model on a specified task using the lm-evaluation-harness library. - - Args: - model (Transformer): The pre-trained language model to evaluate. - tokenizer: The tokenizer to use for encoding/decoding text. - task (str): The name of the evaluation task to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - max_seq_length (Optional[int]): The maximum sequence length allowed for input text. - - Returns: - eval_results (dict): A dictionary of evaluation results for the specified task(s). - """ - model_eval_wrapper = GPTFastEvalWrapper( - model, - tokenizer, - max_seq_length, + print( + f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}" ) + max_mem_gb = torch.cuda.max_memory_reserved() / 1e9 + print(f"Memory used: {max_mem_gb} GB") + task_metrics["max_memory_gb"] = max_mem_gb - try: - lm_eval.tasks.initialize_tasks() - except: - pass + for k, v in aggregate_metrics.items(): + task_metrics[k] = sum(v) / len(v) - if 'hendrycks_test' in tasks: - tasks.remove('hendrycks_test') - tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] - task_dict = get_task_dict(tasks) + if task.requires_logits: + metrics = task.test_metrics(all_probs) + else: + metrics = task.test_metrics(predictions) - eval_results = evaluate( - model_eval_wrapper, - task_dict, - limit=limit, - ) - return eval_results + pred_df = pd.DataFrame({"prompt": prompts, "prediction": predictions}) + + for k, v in metrics.items(): + if type(v) == dict: + for kk, vv in v.items(): + task_metrics[f"{k}_{kk}"] = vv + else: + task_metrics[k] = v + return task_metrics, pred_df def main( - checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), - compile: bool = False, - tasks: list = ["hellaswag"], - limit: Optional[int] = None, - max_seq_length: Optional[int] = None, + args: argparse.Namespace, + tasks: List[str], + debug: bool = False, + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), + profile: Optional[Path] = None, + compile=True, + feed_long_prompts=False, + device=default_device, + cache_kwargs: dict = {}, + out_dir: Path = None, ) -> None: - """Evaluates model on a task from the `lm-evaluation-harness` library. - - Args: - checkpoint_path (Path): The path to the model checkpoint file to load. - compile (bool): Whether or not to compile the model for optimization. - task (Optional[str]): The name of the evaluation task or a list of tasks to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - max_seq_length (Optional[int]): The maximum sequence length allowed for input text. - - """ - + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - - device = 'cuda' + if not tokenizer_path.is_file(): + # If there's no tokenizer.model, try to load the tokenizer from the parent directory + # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers + tokenizer_path = checkpoint_path.parent + + global print + from tp import maybe_init_dist + + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + if rank != 0: + # only print on rank 0 + print = lambda *args, **kwargs: None + + print(f"Using device={device}") precision = torch.bfloat16 + is_chat = ( + "chat" in str(checkpoint_path).lower() + or "instruct" in str(checkpoint_path).lower() + ) print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, device, precision, False) + model = load_model(checkpoint_path, device, precision, use_tp) - torch.cuda.synchronize() - print(f"Time to load model: {time.time() - t0:.02f} seconds.") + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") - model.eval() + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat) - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + if cache_kwargs["cache_strategy"] == "fastgen": + # We need to pass the special and punctuation token ids to the cache via cache_kwargs + cache_kwargs["token_ids"] = { + "special": tokenizer.special_ids(), + "punctuation": tokenizer.punctuation_ids(), + } + + terminator_ids = tokenizer.get_terminator_ids() torch.manual_seed(1234) if compile: - global model_forward - model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) - torch._inductor.config.coordinate_descent_tuning = True - - t1 = time.time() - result = eval( - model, - tokenizer, - tasks, - limit, - max_seq_length, + global decode_one_token, prefill + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + task_kwargs = { + "model_max_length": model.config.max_length, + "num_samples": args.num_samples, + "tokenizer": tokenizer.encode_prompt if is_chat else tokenizer.encode, + } + if tasks == ["all"]: + # Evaluate all tasks + tasks = list(TASK_MAPPING.keys()) + eval_tasks = {task: AutoTask.from_name(task, **task_kwargs) for task in tasks} + + task_metrics = defaultdict(dict) + args_fn = out_dir / "args.json" + all_out_fn = out_dir / "all_metrics.json" + for task_name, task in eval_tasks.items(): + print(f"Running task {task_name} ...") + task_out_fn = out_dir / f"{task_name}_metrics.json" + pred_out_fn = out_dir / f"{task_name}_predictions.csv" + if task_out_fn.exists() and not cache_kwargs["overwrite"]: + print(f"Task {task_name} already evaluated. Skipping.") + with open(task_out_fn, "r") as fd: + task_metrics[task_name] = json.load(fd) + else: + task_metrics[task_name], predictions = run_task( + args, + task, + model, + tokenizer, + is_chat, + profile, + feed_long_prompts, + device, + cache_kwargs, + use_tp, + rank, + terminator_ids, + ) + + predictions.to_csv(pred_out_fn, index=False) + + if debug: + print(f"Results for {task_name}:") + print(task_metrics[task_name]) + + with open(task_out_fn, "w") as fd: + print(f"Saving results for {task_name} to {task_out_fn}") + json.dump(task_metrics[task_name], fd, indent=2) + + if not args_fn.exists(): + # Only save args once and only save if we've gotten through a full eval and are ready to dump metrics + with open(args_fn, "w") as fd: + # Convert Path objects to strings + cache_kwargs_json = { + k: str(v) if isinstance(v, Path) else v + for k, v in cache_kwargs.items() + } + json.dump(cache_kwargs_json, fd, indent=2) + + with open(all_out_fn, "w") as fd: + json.dump(task_metrics, fd, indent=2) + + +def setup(args) -> Path: + out_dir = ( + Path(__file__).parent + / "results" + / args.checkpoint_path.parent.stem + / args.cache_strategy + / args_to_str(args) + ) + + print(f"Saving to {out_dir}") + # Make out_dir and don't err out if it already exists + if out_dir.exists(): + print(f"Output directory {out_dir} already exists.") + if args.overwrite: + print(f"Removing {out_dir}.") + shutil.rmtree(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + cache_compatibility(args) + + for k, v in vars(args).items(): + print(f"{k} -> {v}") + + return out_dir + + +def add_eval_args(parser): + parser.add_argument( + "--tasks", + type=str, + nargs="+", + default=["truthfulqa"], + choices=list(TASK_MAPPING.keys()), + help="List of tasks to be evaluated.", + ) + + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="Debug mode uses first 10 examples in dataset.", + ) + + parser.add_argument( + "--num_samples", + type=int, + default=None, + help="Number of examples to sample for evaluation. Defaults to None, which uses the full dataset.", + ) + + parser.add_argument( + "--overwrite", + default=False, + action="store_true", + help="Whether to over-write existing results if they exist.", + ) + + parser.add_argument( + "--cache_config", + type=str, + default=None, + help="Name of YAML file in ./cache_configs.", + ) + + +def merge_cache_config(args): + if not args.cache_config: + return args + # Get parent directory of current file + if not args.cache_config.endswith(".yaml"): + args.cache_config = args.cache_config + ".yaml" + yaml_fn = Path(__file__).parent / "cache_configs" / args.cache_config + assert yaml_fn.exists(), f"Cache config file {yaml_fn} does not exist." + with open(yaml_fn, "r") as f: + cache_kwargs = yaml.safe_load(f) + # Over-write args with cache_kwargs + args = argparse.Namespace(**{**vars(args), **cache_kwargs}) + return args + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluation script for different KV-Cache Compression Algorithms." ) - print(f"Time to run eval: {time.time() - t1:.02f} seconds.") - print(f"For model {checkpoint_path}") - for task, res in result["results"].items(): - print(f"{task}: {res}") + add_eval_args(parser) + add_generation_arguments(parser) + add_cache_arguments(parser) -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') + args = merge_cache_config(parser.parse_args()) - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2') - parser.add_argument('--limit', type=int, default=None, help='number of samples to evalulate') - parser.add_argument('--max_seq_length', type=int, default=None, help='maximum length sequence to evaluate') + out_dir = setup(args) - args = parser.parse_args() main( - Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length, + args, + args.tasks, + args.debug, + args.checkpoint_path, + args.profile, + args.compile, + args.feed_long_prompts, + args.device, + cache_kwargs=vars(args), + out_dir=out_dir, ) diff --git a/eval_multi.py b/eval_multi.py new file mode 100644 index 0000000..93625f2 --- /dev/null +++ b/eval_multi.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import sys +import argparse +from pathlib import Path + +import torch +import torch._dynamo.config +import torch._inductor.config + + +from cache import add_cache_arguments +from generation_utils import add_generation_arguments + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = "cuda" if torch.cuda.is_available() else "cpu" + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from eval import add_eval_args, setup, merge_cache_config, main as eval_main + + +HPARAMS = { + "max_cache_length": [[8192], [4096], [2048], [1024], [512], [256], [128]], + "min_recovery_frac": [0.5, 0.6, 0.7, 0.8, 0.9, 0.95], +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Sweep a hyper-parameter for a KV-Cache Compression Algorithms." + ) + + parser.add_argument( + "--hparam", + default="max_cache_length", + help="The hyper-parameter to sweep.", + ) + + add_eval_args(parser) + add_generation_arguments(parser) + add_cache_arguments(parser) + + args = merge_cache_config(parser.parse_args()) + + assert args.hparam in HPARAMS, f"Set {args.hparam} in HPARAMS dictionary first." + + for v in HPARAMS[args.hparam]: + # Copy the args object to avoid modifying the original + exp_args = argparse.Namespace(**vars(args)) + print(f"Setting {args.hparam} to {v}") + setattr(exp_args, args.hparam, v) + + out_dir = setup(exp_args) + + eval_main( + args, + args.tasks, + args.debug, + args.checkpoint_path, + args.profile, + args.compile, + args.feed_long_prompts, + args.device, + cache_kwargs=vars(exp_args), + out_dir=out_dir, + ) diff --git a/generate.py b/generate.py index c58a224..db8ce54 100644 --- a/generate.py +++ b/generate.py @@ -3,284 +3,65 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import itertools import sys import time +import contextlib +import json from pathlib import Path -from typing import Optional, Tuple +from typing import Optional import torch import torch._dynamo.config import torch._inductor.config -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet suppported") - +from cache import add_cache_arguments +from generation_utils import ( + add_generation_arguments, + compute_max_seq_length, + decode_one_token, + device_sync, + prefill, +) torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' +default_device = "cuda" if torch.cuda.is_available() else "cpu" # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from model import Transformer -from tokenizer import get_tokenizer - -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) - - return new_tokens, new_probs - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - -def speculative_decode( - model: Transformer, - draft_model: Transformer, - cur_token: torch.Tensor, - input_pos: int, - speculate_k: int, - **sampling_kwargs -) -> torch.Tensor: - # draft model inference sequentially - device = cur_token.device - orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) - draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) - - draft_tokens = torch.cat(draft_tokens) - # parallel inference on target model using draft tokens - target_logits = model_forward( - model, - torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), - torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) - ) - target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) - draft_probs = torch.stack(draft_probs) - # q: target prob, p: draft prob - # q >= p: always accept draft token - # q < p: q/p prob to accept draft token - p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] - accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) - rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() - - if rejected_locations.shape[0] == 0: # All draft tokens have been accepted - accept_length = speculate_k + 1 - last_token = multinomial_sample_one_no_sync(target_probs[-1]) - # fill last token into draft model - model_forward( - draft_model, - draft_tokens[-1].view(1, -1), - orig_input_pos + speculate_k, - ) - return torch.cat([draft_tokens, last_token]) - else: - accept_length = rejected_locations[0].item() - p = draft_probs[accept_length] - q = target_probs[accept_length] - new = q - p - new = torch.where(new > 0, new, 0.0) - new = new / new.sum() - next_token = multinomial_sample_one_no_sync(new) - return torch.cat([draft_tokens[:accept_length], next_token]) - -@torch.no_grad() -def generate( - model: Transformer, - prompt: torch.Tensor, - max_new_tokens: int, - *, - interactive: bool, - draft_model: Transformer, - speculate_k: Optional[int] = 8, - callback = lambda x: x, - **sampling_kwargs -) -> torch.Tensor: - """ - Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. - """ - - is_speculative = draft_model is not None - # create an empty tensor of the expected final shape and fill in the current tokens - T = prompt.size(0) - T_new = T + max_new_tokens - if interactive: - max_seq_length = 350 - else: - max_seq_length = min(T_new, model.config.block_size) - - device, dtype = prompt.device, prompt.dtype - max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - if is_speculative and draft_model is not model: - draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # create an empty tensor of the expected final shape and fill in the current tokens - empty = torch.empty(T_new, dtype=dtype, device=device) - empty[:T] = prompt - seq = empty - input_pos = torch.arange(0, T, device=device) - - next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone() - if is_speculative: - prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) - seq[T] = next_token - - input_pos = torch.tensor([T], device=device, dtype=torch.int) - accept_counts = [0] * (speculate_k + 1) - - if is_speculative: - input_pos = input_pos.item() # for speculative decoding easier to keep on host - while input_pos < T_new - 1: - cur_token = next_token.view(()) - - next_tokens = speculative_decode( - model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs - ) - - accept_counts[len(next_tokens) - 1] += 1 - num_added = min(T_new - input_pos - 1, len(next_tokens)) - seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] - for i in next_tokens[: num_added,]: - callback(i) - input_pos = input_pos + num_added - next_token = next_tokens[-1] - else: - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) - - generate_stats = { - 'accept_counts': accept_counts - } - return seq, generate_stats - -def encode_tokens(tokenizer, string, bos=True, device=default_device): - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - -def _load_model(checkpoint_path, device, precision, use_tp): - use_cuda = 'cuda' in device - with torch.device('meta'): - model = Transformer.from_name(checkpoint_path.parent.name) - - if "int8" in str(checkpoint_path): - print("Using int8 weight-only quantization!") - from quantize import WeightOnlyInt8QuantHandler - simple_quantizer = WeightOnlyInt8QuantHandler(model) - model = simple_quantizer.convert_for_runtime() - - if "int4" in str(checkpoint_path): - print("Using int4 weight-only quantization!") - path_comps = checkpoint_path.name.split(".") - groupsize = int(path_comps[-2][1:]) - from quantize import WeightOnlyInt4QuantHandler - simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - model = simple_quantizer.convert_for_runtime() - - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - model.load_state_dict(checkpoint, assign=True) +from tokenizer import get_tokenizer, encode +from generation_utils import generate, load_model, get_model_size, setup_caches +from cache import add_cache_arguments, cache_compatibility - if use_tp: - from tp import apply_tp - print("Applying tensor parallel to model ...") - apply_tp(model) - - model = model.to(device=device, dtype=precision) - return model.eval() - -def _get_model_size(model): - model_size = 0 - for name, child in model.named_children(): - if not isinstance(child, torch.nn.Embedding): - model_size += sum( - [ - p.numel() * p.dtype.itemsize - for p in itertools.chain(child.parameters(), child.buffers()) - ] - ) - return model_size - -B_INST, E_INST = "[INST]", "[/INST]" def main( prompt: str = "Hello, my name is", - interactive: bool = False, - num_samples: int = 5, max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), compile: bool = True, - compile_prefill: bool = False, + feed_long_prompts: bool = False, profile: Optional[Path] = None, - draft_checkpoint_path: Optional[Path] = None, - speculate_k: int = 5, device=default_device, + cache_kwargs: dict = {}, ) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. - """ + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) + if not tokenizer_path.is_file(): + # If there's no tokenizer.model, try to load the tokenizer from the parent directory + # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers + tokenizer_path = checkpoint_path.parent global print from tp import maybe_init_dist + rank = maybe_init_dist() use_tp = rank is not None if use_tp: @@ -290,147 +71,125 @@ def main( print(f"Using device={device}") precision = torch.bfloat16 - is_speculative = draft_checkpoint_path is not None - is_chat = "chat" in str(checkpoint_path) + is_chat = ( + "chat" in str(checkpoint_path).lower() + or "instruct" in str(checkpoint_path).lower() + ) print("Loading model ...") t0 = time.time() - model = _load_model(checkpoint_path, device, precision, use_tp) + model = load_model(checkpoint_path, device, precision, use_tp) - if is_speculative: - draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) - else: - draft_model = None - - device_sync(device=device) # MKG + device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - prompt_length = encoded.size(0) + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat) - torch.manual_seed(1234) - model_size = _get_model_size(model) - if compile: - if is_speculative and use_tp: # and ("cuda" in device): - torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case + gist_token_id = tokenizer.gist_token_id() if hasattr(tokenizer, "gist_token_id") else None - if is_speculative: - global model_forward, logits_to_prob - model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) + inputs = [encode(tokenizer, prompt, device=device, is_chat=is_chat)] - global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + terminator_ids = tokenizer.get_terminator_ids() - # Uncomment to squeeze more perf out of prefill - if compile_prefill: - prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + torch.manual_seed(1234) + model_size = get_model_size(model) + print(f"{model_size / 1e9:.02f} billion parameters in model.") + if compile: + global decode_one_token, prefill + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) aggregate_metrics = { - 'tokens_per_sec': [], - 'accept_counts': [], + "tokens_per_sec": [], } - start = -1 if compile else 0 - - for i in range(start, num_samples): - device_sync(device=device) # MKG - if i >= 0 and interactive: - prompt = input("What is your prompt? ") - if is_chat: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - - if interactive and i >= 0: - buffer = [] - period_id = tokenizer.encode('.')[0] - done_generating = False - def callback(x): - nonlocal done_generating - if done_generating: - return - buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) - if x.item() == tokenizer.eos_id(): - done_generating = True - if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) - buffer.clear() - # print(, end='', flush=True) - else: - callback = lambda x : x - t0 = time.perf_counter() - import contextlib - if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): - prof = contextlib.nullcontext() - else: - torch.profiler._utils._init_for_cuda_graphs() - prof = torch.profiler.profile() - with prof: - y, metrics = generate( - model, - encoded, - max_new_tokens, - draft_model=draft_model, - speculate_k=speculate_k, - interactive=interactive, - callback=callback, - temperature=temperature, - top_k=top_k, - ) - aggregate_metrics['accept_counts'].append(metrics['accept_counts']) - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue - if hasattr(prof, "export_chrome_trace"): - if use_tp: - prof.export_chrome_trace(f"{profile}_rank_{rank}.json") - else: - prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG - t = time.perf_counter() - t0 - - if not interactive: - print(tokenizer.decode(y.tolist())) + + device_sync(device=device) # MKG + + max_prompt_length, max_seq_length = compute_max_seq_length( + model, inputs, max_new_tokens + ) + max_new_tokens = min(max_new_tokens, max_seq_length - max_prompt_length) + setup_caches(model, tokenizer, inputs[0].device, max_seq_length, cache_kwargs) + t0 = time.perf_counter() + + if (not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + + with prof: + y, _ = generate( + model, + inputs[0], + max_new_tokens=max_new_tokens, + terminator_ids=terminator_ids, + gist_token_id=gist_token_id, + feed_long_prompts=feed_long_prompts, + ) + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") else: - print() - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") - print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + print(tokenizer.decode(y.tolist())) + tokens_generated = y.size(0) - max_prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print(f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Tokens generated: {tokens_generated}") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") print("==========") - if is_speculative: - counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] - acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] - print(f"Acceptance probs: {acceptance_probs}") - print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") - print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print( + f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}" + ) print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') - parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') + + parser = argparse.ArgumentParser( + description="Run Simple Single Prompt Generation (for development and debugging purposes)." + ) + parser.add_argument( + "--prompt", + type=str, + default="long_prompt_short_output.json", + help="Input prompt. If it ends in .json, we will load the prompt from the ./prompts dir.", + ) + parser.add_argument( + "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens." + ) + + add_generation_arguments(parser) + add_cache_arguments(parser) args = parser.parse_args() + + if args.prompt.endswith(".json"): + prompt_fn = Path(__file__).resolve().parent / "prompts" / args.prompt + with open(prompt_fn) as fd: + args.prompt = json.load(fd) + + cache_compatibility(args) + main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, - args.speculate_k, args.device + args.prompt, + args.max_new_tokens, + args.checkpoint_path, + args.compile, + args.feed_long_prompts, + args.profile, + args.device, + cache_kwargs=vars(args), ) + diff --git a/generation_utils.py b/generation_utils.py new file mode 100644 index 0000000..f033db5 --- /dev/null +++ b/generation_utils.py @@ -0,0 +1,378 @@ +import itertools +from typing import Optional, Tuple +from pathlib import Path + +import torch +import torch._dynamo.config +import torch._inductor.config + +import argparse +from model import Transformer, find_multiple +from tokenizer import TokenizerInterface + + +default_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def add_generation_arguments(parser: argparse.ArgumentParser): + group = parser.add_argument_group("generation_args") + # Generation hparams + group.add_argument( + "--checkpoint_path", + type=Path, + default=Path(__file__).resolve().parent + / "checkpoints/Qwen/Qwen2-1.5B-Instruct/model.pth", + help="Model checkpoint path.", + ) + + group.add_argument("--profile", type=Path, default=None, help="Profile path.") + + group.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + + group.add_argument( + "--device", type=str, default=default_device, help="Device to use" + ) + + +def compute_max_seq_length(model, prompt_lens, max_new_tokens) -> int: + max_prompt_length = max(len(prompt_lens[i]) for i in range(len(prompt_lens))) + max_seq_length = max_prompt_length + max_new_tokens + if max_seq_length > model.config.block_size: + print( + f"Warning: The longest prompt puts the desired max_seq_length at {max_seq_length}, which is greater than models max of {model.config.block_size}." + ) + print(f"Setting to model's max_seq_length of {model.config.block_size}.") + max_seq_length = model.config.block_size + print(f"Maximum context length of {max_seq_length} tokens.") + return max_prompt_length, max_seq_length + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample( + logits: torch.Tensor, + next_token: torch.Tensor = None, + temperature: float = 1.0, + top_k: Optional[int] = None, +): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + if next_token is None: + idx_next = multinomial_sample_one_no_sync(probs) + else: + idx_next = next_token + return idx_next, probs + + +def greedy(logits, next_token): + probs = torch.nn.functional.softmax(logits[0, -1], dim=-1) + if next_token is None: + idx_next = torch.argmax(probs, keepdim=True).to(dtype=torch.int) + else: + idx_next = next_token + return idx_next, probs + + +def prefill( + model: Transformer, + x: torch.Tensor, + input_pos: torch.Tensor, + next_token: torch.Tensor = None, + gist_token_id: Optional[int] = -1, + **sampling_kwargs, +) -> torch.Tensor: + # input_pos: [B, S] + causal_mask = ( + torch.tril(torch.ones(len(input_pos), len(input_pos), dtype=torch.bool)) + .unsqueeze(0) + .unsqueeze(0) + .to(x.device) + ) + if gist_token_id is not None: + gist_token_positions = torch.stack(torch.where(x == gist_token_id)).T + for position in gist_token_positions: + causal_mask[position[0], :, position[1] + 1:, :position[1]] = False + + logits = model(x, input_pos, mask=causal_mask) + return greedy(logits, next_token) + + +def decode_one_token( + model: Transformer, + x: torch.Tensor, + input_pos: torch.Tensor, + next_token: torch.Tensor = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return greedy(logits, next_token=next_token) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + terminator_ids: Optional[list] = None, + prefix: Optional[torch.Tensor] = None, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + teacher_force = prefix is not None and i < len(prefix) + next_token = prefix[i].view(1) if teacher_force else None + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, next_token=next_token, **sampling_kwargs + ) + + new_tokens.append(next_token.clone()) + new_probs.append(next_prob.clone()) + + if terminator_ids and next_token in terminator_ids and not teacher_force: + break + + input_pos += 1 + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + + +def normalize_cache_length( + max_cache_length: float, max_seq_length: int, multiple_of: int = 8 +) -> int: + """ + Computes the absolute cache length given the max_cache_length and max_seq_length. + """ + if 0 < max_cache_length <= 1: + max_cache_length = round(max_seq_length * max_cache_length) + else: + assert int(max_cache_length) == max_cache_length + max_cache_length = int(max_cache_length) + if max_cache_length > max_seq_length: + print( + f"Warning: max_cache_length ({max_cache_length}) is greater than max_seq_length ({max_seq_length}). Setting to {max_seq_length}" + ) + max_cache_length = max_seq_length + return min(find_multiple(max_cache_length, multiple_of), max_seq_length) + + +def setup_caches( + model: Transformer, + tokenizer: TokenizerInterface, + device: torch.device, + max_seq_length: int, + cache_kwargs: dict = None, +): + cache_kwargs["max_seq_length"] = max_seq_length + # Normalize max_cache_length to absolute cache length if provided as a fraction of the max seq sequence length + cache_kwargs["max_cache_length"] = list( + map( + lambda l: normalize_cache_length(l, max_seq_length), + cache_kwargs["max_cache_length"], + ) + ) + + assert ( + model.config.n_layer % len(cache_kwargs["max_cache_length"]) == 0 + ), f'max_cache_length ({len(cache_kwargs["max_cache_length"])}) must be a factor of {model.config.n_layer} layers.' + + tile_size = model.config.n_layer // len(cache_kwargs["max_cache_length"]) + cache_kwargs["max_cache_length"] = [ + item for item in cache_kwargs["max_cache_length"] for _ in range(tile_size) + ] + + if type(cache_kwargs["recent_window"]) != list: + if cache_kwargs["recent_window"] <= 1: + cache_kwargs["recent_window"] = [ + max(1, int(cache_kwargs["recent_window"] * l)) + for l in cache_kwargs["max_cache_length"] + ] + else: + cache_kwargs["recent_window"] = [ + max(1, min(cache_kwargs["recent_window"], l)) + for l in cache_kwargs["max_cache_length"] + ] + + # Gets called twice when model is wrapped in torch.compile which causes an error without the if statement + if type(cache_kwargs["drop_amount"]) != list: + cache_kwargs["drop_amount"] = [ + max(int(cache_kwargs["drop_amount"] * l), 1) + for l in cache_kwargs["max_cache_length"] + ] + + assert cache_kwargs["global_tokens"] <= min( + cache_kwargs["max_cache_length"] + ), "Global tokens must be less than max_cache_length." + + if cache_kwargs["cache_strategy"] == "fastgen": + # We need to pass the special and punctuation token ids to the cache via cache_kwargs + cache_kwargs["token_ids"] = { + "special": tokenizer.special_ids(), + "punctuation": tokenizer.punctuation_ids(), + } + + if "gist" in cache_kwargs["cache_strategy"]: + cache_kwargs["gist_token_id"] = tokenizer.gist_token_id() + + with torch.device(device): + model.setup_caches(max_batch_size=1, **cache_kwargs) + + +def reset_caches(model: Transformer): + model.reset_caches() + + +def get_cache_stats(model: Transformer, prompt_len: int, gen_len: int): + return model.get_cache_stats(prompt_len, gen_len) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + terminator_ids: Optional[list] = None, + gist_token_id: int = -1, + feed_long_prompts: bool = False, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + # create an empty tensor of the expected final shape and fill in the current tokens + prompt_length = prompt.size(0) + + device, dtype = prompt.device, prompt.dtype + + min_cache_length = model.min_cache_length() + # Subtract 1 in case we need one generation step over which to compute attention, etc. + max_prompt_len = min_cache_length - 1 + prefix = None + # If we asked to have prompt truncated and fed, we need to do split prompt into prompt and prefix + # We also define a rare yet important edge case: if |prompt| is exactly cache length + # We might have to start evictions before having had a change to record any state (attentions). + # In this scenario let's decrement prompt by 1 and start "generating" on the prefix + if ( + feed_long_prompts and prompt_length > max_prompt_len + ) or prompt_length == min_cache_length: + prompt, prefix = prompt[:max_prompt_len], prompt[max_prompt_len:] + max_new_tokens += len(prefix) + prompt_length = max_prompt_len + # create an empty tensor (all -1) of the expected final shape and fill in the current tokens + # GPT-Fast had this as empty but the values of empty are non-deterministic + seq = torch.full((prompt_length + max_new_tokens,), -1, dtype=dtype, device=device) + seq[:prompt_length] = prompt + input_pos = torch.arange(0, prompt_length, device=device) + + ret = prefill( + model, + prompt.view(1, -1), + input_pos, + next_token=None if prefix is None else prefix[0].view(1), + gist_token_id=gist_token_id, + **sampling_kwargs, + ) + next_token = ret[0].clone() + next_tok_probs = ret[1].clone() + seq[prompt_length] = next_token + + input_pos = torch.tensor([prompt_length], device=device, dtype=torch.int) + generated_tokens, generated_tok_probs = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + max_new_tokens - 1, + terminator_ids=terminator_ids, + prefix=None if prefix is None else prefix[1:], + **sampling_kwargs, + ) + if len(generated_tokens) > 0: + seq[prompt_length + 1 : prompt_length + 1 + len(generated_tokens)] = torch.cat( + generated_tokens + ) + + # Truncate seq to first instance of -1 if -1 is present + if -1 in seq: + seq = seq[: torch.where(seq == -1)[0][0]] + + return seq, [next_tok_probs] + generated_tok_probs + + +def load_model(checkpoint_path, device, precision, use_tp): + use_cuda = "cuda" in device + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 weight-only quantization!") + path_comps = checkpoint_path.name.split(".") + groupsize = int(path_comps[-2][1:]) + from quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + if use_tp: + from tp import apply_tp + + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + + +def get_model_size(model): + model_size = 0 + for name, child in model.named_children(): + if not isinstance(child, torch.nn.Embedding): + for p in itertools.chain(child.parameters(), child.buffers()): + model_size += p.numel() * p.dtype.itemsize + return model_size diff --git a/metric.py b/metric.py new file mode 100644 index 0000000..2fe9399 --- /dev/null +++ b/metric.py @@ -0,0 +1,285 @@ +import os + +import numpy as np +import regex as re +from claudette import Chat, models +from evaluate import load +import regex as re + + +class Metric: + def __init__(self, **kwargs): + self._load_metric(**kwargs) + + def _load_metric(self, **kwargs): + raise NotImplementedError("This method should be overridden by subclasses.") + + def compute(self, prompts, predictions, references): + raise NotImplementedError("This method should be overridden by subclasses.") + + +class Rouge(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_metric(self, **kwargs): + self.metric = load("rouge", keep_in_memory=True) + + def compute(self, prompts, predictions, references): + return self.metric.compute(predictions=predictions, references=references) + + +class Bleurt(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_metric(self, **kwargs): + self.metric = load("bleurt", keep_in_memory=True) + + def compute(self, prompts, predictions, references): + return np.mean( + self.metric.compute(predictions=predictions, references=references)[ + "scores" + ] + ) + + +class BertScore(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_metric(self, **kwargs): + self.metric = load("bertscore", keep_in_memory=True) + + def compute(self, prompts, predictions, references): + result = self.metric.compute( + predictions=predictions, references=references, lang="en" + ) + return { + "precision": np.mean(result["precision"]), + "recall": np.mean(result["recall"]), + "f1": np.mean(result["f1"]), + } + + +class Accuracy(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _load_metric(self, **kwargs): + from sklearn.metrics import accuracy_score + + self.metric = accuracy_score + + def compute(self, prompts, predictions, references): + return self.metric(references, predictions) + + +class RulerStringMatch(Metric): + """ + Metric used in RULER. + Reference: https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @staticmethod + def postprocess_pred(predict_str: str): + predict_str = predict_str.strip() + + # Remove all non-printable characters + np_pattern = re.compile(r"[\x00-\x1f]") + predict_str = np_pattern.sub("\n", predict_str).strip() + + return predict_str + + @staticmethod + def string_match_part(refs, preds): + scores = [ + max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) + for pred, ref in zip(preds, refs) + ] + score = sum(scores) / len(preds) * 100 + return {"score": round(score, 4)} + + @staticmethod + def string_match_all(refs, preds): + scores = [ + sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) + for pred, ref in zip(preds, refs) + ] + score = sum(scores) / len(preds) * 100 + return {"score": round(score, 4)} + + def _load_metric(self, **kwargs): + if kwargs.get("match_part", False): + self.metric = self.string_match_part + else: + self.metric = self.string_match_all + + def compute(self, prompts, predictions, references): + predictions = [self.postprocess_pred(pred) for pred in predictions] + return self.metric(references, predictions) + + +REFERENCE_TEMPLATE = """You are shown ground-truth answer(s) and asked to judge the quality of an LLM-generated answer. +Assign it a score from 1-5 where 1 is the worst and 5 is the best based on how similar it is to the ground-truth(s). +Do NOT explain your choice. Simply return a number from 1-5. + +====GROUND TRUTHS==== +{labels} + +====ANSWER==== +{prediction}""" + +PREFILL = "The score (1-5) is:" + + +class LLMRouge(Metric): + def __init__(self, **kwargs) -> None: + assert ( + "ANTHROPIC_API_KEY" in os.environ + ), "Please set the ANTHROPIC_API_KEY environment variable." + super().__init__(**kwargs) + + def _load_metric(self, **kwargs): + name = kwargs.get("name", "haiku") + matching_names = [m for m in models if name in m] + assert len(matching_names) > 0, f"Model name {name} not found in {models}" + assert ( + len(matching_names) == 1 + ), f"Model name {name} found x{len(matching_names)} in {models}" + self.chat = Chat( + matching_names[0], sp="""You are a helpful and concise assistant.""" + ) + + def parse_int(self, text): + return int(re.search(r"\d+", text).group()) + + def compute(self, prompts, predictions, labels): + scores = [] + for p, ls in zip(predictions, labels): + prompt = REFERENCE_TEMPLATE.format(labels="\n---\n".join(ls), prediction=p) + # Clear conversation history + self.chat.h = [] + score = ( + self.chat(prompt, prefill=PREFILL) + .content[0] + .text[len(PREFILL) :] + .strip() + ) + score = self.parse_int(score) + scores.append(score) + return {"llm_rouge": sum(scores) / len(scores)} + + +LLM_JUDGE_TEMPLATE = """You are shown a prompt and asked to assess the quality of an LLM-generated answer on the following dimensions: + +===CRITERIA=== +{criteria} + +Respond with "criteria: score" for each criteria with a newline for each criteria. +Assign a score from 1-5 where 1 is the worst and 5 is the best based on how well the answer meets the criteria. + +====PROMPT==== +{prompt} + +====ANSWER==== +{prediction}""" + + +CRITERIA = { + "helpful": "The answer executes the action requested by the prompt without extraneous detail.", + "coherent": "The answer is logically structured and coherent (ignore the prompt).", + "faithful": "The answer is faithful to the prompt and does not contain false information.", +} + + +class LLMJudge(LLMRouge): + def __init__(self, **kwargs) -> None: + assert ( + "ANTHROPIC_API_KEY" in os.environ + ), "Please set the ANTHROPIC_API_KEY environment variable." + super().__init__(**kwargs) + + self.criteria = list(sorted([k for k in CRITERIA])) + self.criteria_def = "\n".join([f"{k}: {CRITERIA[k]}" for k in self.criteria]) + self.prefill = ( + f"\n\n====SCORES for {', '.join(self.criteria)}====\n\n{self.criteria[0]}:" + ) + + def parse_scorecard(self, scorecard): + try: + return { + k: int(v) + for k, v in dict( + re.findall(rf"({'|'.join(self.criteria)})\W+(\d+)", scorecard) + ).items() + } + except Exception as e: + print(e) + raise Exception( + f"Could not parse LLM-generated scorecard for {self.__class__}:\n{scorecard}" + ) + + def claudette_scorecard(self, prompt, prediction): + prompt = LLM_JUDGE_TEMPLATE.format( + criteria=self.criteria_def, prompt=prompt, prediction=prediction + ) + # Clear conversation history + self.chat.h = [] + scorecard = ( + self.chat(prompt, prefill=self.prefill) + .content[0] + .text[len(self.prefill) - len(self.criteria[0]) - 1 :] + .strip() + ) + return scorecard + + def compute(self, prompts, predictions, labels): + scores = [] + + for prompt, pred in zip(prompts, predictions): + scorecard = self.claudette_scorecard(prompt, pred) + score_dict = self.parse_scorecard(scorecard) + scores.append(score_dict) + + return {k: np.mean([s[k] for s in scores]) for k in self.criteria} + + +METRIC_MAPPING = { + "accuracy": Accuracy, + "bertscore": BertScore, + "bleurt": Bleurt, + "llm-rouge": LLMRouge, + "llm-as-a-judge": LLMJudge, + "rouge": Rouge, + "ruler-string-match": RulerStringMatch, +} + + +class AutoMetric: + def __init__(self): + raise EnvironmentError( + "This class is designed to be instantiated only through the from_name method" + ) + + def from_name(metric_name, **kwargs): + if metric_name not in METRIC_MAPPING: + raise ValueError(f"Invalid metric name: {metric_name}") + return METRIC_MAPPING[metric_name](**kwargs) + + +if __name__ == "__main__": + metric = AutoMetric.from_name("llm-as-a-judge") + predictions = [ + "The answer to 2x2 is 4.", + "The answer to 2x2 is 5.", + ] + labels = [["4"], ["4"]] + prompts = [ + "What is 2x2?", + "What is 2x2?", + ] + print(metric.compute(prompts=prompts, predictions=predictions, labels=None)) diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py index 9aa076b..0bf477d 100644 --- a/mixtral-moe/generate.py +++ b/mixtral-moe/generate.py @@ -13,6 +13,7 @@ import torch._dynamo.config import torch._inductor.config + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) @@ -24,7 +25,7 @@ def device_sync(device): torch._inductor.config.coordinate_descent_tuning = True torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future # support running without installing as a package @@ -37,10 +38,13 @@ def device_sync(device): from tp import maybe_init_dist -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs_sort).exponential_(1) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): logits = logits / max(temperature, 1e-5) @@ -51,26 +55,43 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non probs = torch.nn.functional.softmax(logits, dim=-1) return probs + def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): probs = logits_to_probs(logits[0, -1], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: # input_pos: [B, S] logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) return sample(logits, **sampling_kwargs) -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): new_tokens, new_probs = [], [] for i in range(num_new_tokens): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) @@ -86,6 +107,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc def model_forward(model, x, input_pos): return model(x, input_pos) + @torch.no_grad() def generate( model: Transformer, @@ -93,8 +115,8 @@ def generate( max_new_tokens: int, *, interactive: bool, - callback = lambda x: x, - **sampling_kwargs + callback=lambda x: x, + **sampling_kwargs, ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. @@ -123,24 +145,34 @@ def generate( input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1), + input_pos, + max_new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq[T + 1 :] = torch.cat(generated_tokens) return seq -def encode_tokens(tokenizer, string, bos=True, device='cuda'): + +def encode_tokens(tokenizer, string, bos=True, device="cuda"): tokens = tokenizer.encode(string) if bos: tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) + def _load_model(checkpoint_path, device, precision, use_tp): - with torch.device('meta'): + with torch.device("meta"): model = Transformer.from_name(checkpoint_path.parent.name) if "int8" in str(checkpoint_path): print("Using int8 weight-only quantization!") from quantize import WeightOnlyBit8QuantHandler + simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8) model = simple_quantizer.convert_for_runtime() @@ -149,14 +181,17 @@ def _load_model(checkpoint_path, device, precision, use_tp): if use_tp: from tp import apply_tp + print("Applying tensor parallel to model ...") apply_tp(model) model = model.to(device=device, dtype=precision) return model.eval() + B_INST, E_INST = "[INST]", "[/INST]" + def main( prompt: str = "Hello, my name is", interactive: bool = False, @@ -168,10 +203,9 @@ def main( compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, - device='cuda', + device="cuda", ) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. - """ + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" assert checkpoint_path.is_file(), checkpoint_path tokenizer_path = checkpoint_path.parent / "tokenizer.model" @@ -193,7 +227,7 @@ def main( t0 = time.time() model = _load_model(checkpoint_path, device, precision, use_tp) - device_sync(device=device) # MKG + device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) @@ -201,25 +235,31 @@ def main( prompt_length = encoded.size(0) torch.manual_seed(1234) - model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + model_size = sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(model.parameters(), model.buffers()) + ] + ) if compile: torch._inductor.config.assert_indirect_indexing = False global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) # Uncomment to squeeze more perf out of prefill if args.compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) - aggregate_metrics = { - 'tokens_per_sec': [], + "tokens_per_sec": [], } start = -1 if compile else 0 for i in range(start, num_samples): - device_sync(device=device) # MKG + device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") if is_chat: @@ -228,8 +268,9 @@ def main( if interactive and i >= 0: buffer = [] - period_id = tokenizer.encode('.')[0] + period_id = tokenizer.encode(".")[0] done_generating = False + def callback(x): nonlocal done_generating if done_generating: @@ -238,13 +279,14 @@ def callback(x): if x.item() == tokenizer.eos_id(): done_generating = True if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) + print("".join(buffer), end="", flush=True) buffer.clear() # print(, end='', flush=True) else: - callback = lambda x : x + callback = lambda x: x t0 = time.perf_counter() import contextlib + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): prof = contextlib.nullcontext() else: @@ -268,7 +310,7 @@ def callback(x): prof.export_chrome_trace(f"{profile}_rank_{rank}.json") else: prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG + device_sync(device=device) # MKG t = time.perf_counter() - t0 if not interactive: @@ -277,32 +319,67 @@ def callback(x): print() tokens_generated = y.size(0) - prompt_length tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec" + ) print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") - print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print( + f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}" + ) print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--device', type=str, default="cuda", help='device to use') + + parser = argparse.ArgumentParser(description="Your CLI description.") + + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument("--device", type=str, default="cuda", help="device to use") args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.top_k, + args.temperature, + args.checkpoint_path, + args.compile, + args.compile_prefill, + args.profile, + args.device, ) diff --git a/mixtral-moe/model.py b/mixtral-moe/model.py index 9249ac9..2aee686 100644 --- a/mixtral-moe/model.py +++ b/mixtral-moe/model.py @@ -17,6 +17,7 @@ def find_multiple(n: int, k: int) -> int: return n return n + k - (n % k) + @dataclass class ModelArgs: block_size: int = 2048 @@ -46,21 +47,38 @@ def from_name(cls, name: str): if name in transformer_configs: return cls(**transformer_configs[name]) # fuzzy search - config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] assert len(config) == 1, name return cls(**transformer_configs[config[0]]) transformer_configs = { - "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), + "Mixtral-8x7B-v0.1": dict( + block_size=32768, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + rope_base=1000000.0, + num_experts=8, + num_activated_experts=2, + ), } + class KVCache(nn.Module): - def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + def __init__( + self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16 + ): super().__init__() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] @@ -73,13 +91,16 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out + class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) - self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) self.output = nn.Linear(config.dim, config.vocab_size, bias=False) @@ -89,17 +110,28 @@ def __init__(self, config: ModelArgs) -> None: self.max_seq_length = -1 def setup_caches(self, max_batch_size, max_seq_length): - if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): return head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) - - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) - self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + ) + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" @@ -126,7 +158,9 @@ def __init__(self, config: ModelArgs) -> None: self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps) - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) out = h + self.block_sparse_moe(self.ffn_norm(h)) return out @@ -156,7 +190,13 @@ def load_hook(self, state_dict, prefix, *args): wv = state_dict.pop(prefix + "wv.weight") state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim @@ -187,17 +227,23 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona class ConditionalFeedForward(nn.Module): def __init__(self, config): super().__init__() - self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) - self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) - self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + self.w1 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) + self.w2 = nn.Parameter( + torch.empty(config.num_experts, config.dim, config.intermediate_size) + ) + self.w3 = nn.Parameter( + torch.empty(config.num_experts, config.intermediate_size, config.dim) + ) def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: - w1_weights = self.w1[expert_indices] # [T, A, D, D] - w3_weights = self.w3[expert_indices] # [T, A, D, D] + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] w2_weights = self.w2[expert_indices] # [T, A, D, D] - x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) - x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) - expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) + x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights)) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return expert_outs @@ -208,16 +254,19 @@ def __init__(self, config) -> None: self.cond_ffn = ConditionalFeedForward(config) self.dim = config.dim self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: x = x.view(-1, self.dim) # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts # x: [T, D] - scores = self.gate(x) # [T, E] + scores = self.gate(x) # [T, E] expert_weights = F.softmax(scores, dim=-1) - expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] - expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_weights, expert_indices = torch.topk( + expert_weights, self.num_activated_experts, dim=-1 + ) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] expert_outs = self.cond_ffn(x, expert_indices) - return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + return torch.einsum("tai,ta -> ti", expert_outs, expert_weights) class RMSNorm(nn.Module): @@ -234,10 +283,10 @@ def forward(self, x: Tensor) -> Tensor: return output * self.weight -def precompute_freqs_cis( - seq_len: int, n_elem: int, base: int = 10000 -) -> Tensor: - freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) t = torch.arange(seq_len, device=freqs.device) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) diff --git a/mixtral-moe/quantize.py b/mixtral-moe/quantize.py index 6312863..3a8029d 100644 --- a/mixtral-moe/quantize.py +++ b/mixtral-moe/quantize.py @@ -14,6 +14,7 @@ ##### Quantization Primitives ###### + def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # assumes symmetric quantization # assumes axis == 0 @@ -51,16 +52,30 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): ##### Weight-only int8 per-channel quantized code ###### + def replace_linear_weight_only_bit8_per_channel(module, target_dtype): for name, child in module.named_children(): if isinstance(child, nn.Linear) and name != "gate": - setattr(module, name, WeightOnlyBit8Linear(child.in_features, child.out_features, target_dtype=target_dtype)) + setattr( + module, + name, + WeightOnlyBit8Linear( + child.in_features, child.out_features, target_dtype=target_dtype + ), + ) elif isinstance(child, ConditionalFeedForward): num_experts, intermediate_size, dim = child.w1.shape - setattr(module, name, ConditionalFeedForwardBit8(num_experts, intermediate_size, dim, target_dtype=target_dtype)) + setattr( + module, + name, + ConditionalFeedForwardBit8( + num_experts, intermediate_size, dim, target_dtype=target_dtype + ), + ) else: replace_linear_weight_only_bit8_per_channel(child, target_dtype) + class WeightOnlyBit8QuantHandler: def __init__(self, mod, target_dtype): self.mod = mod @@ -71,7 +86,9 @@ def create_quantized_state_dict(self): cur_state_dict = self.mod.state_dict() for fqn, mod in self.mod.named_modules(): if isinstance(mod, torch.nn.Linear) and not fqn.endswith(".gate"): - int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, self.target_dtype) + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, self.target_dtype + ) cur_state_dict[f"{fqn}.weight"] = int8_weight cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) elif isinstance(mod, ConditionalFeedForward): @@ -84,12 +101,20 @@ def create_quantized_state_dict(self): bit8_weight_list = [] scales_list = [] for expert_idx in range(num_experts): - bit8_weight, scales, _ = dynamically_quantize_per_channel(weight[expert_idx].float(), -128, 127, self.target_dtype) - bit8_weight_list.append(bit8_weight.reshape(1, intermediate_size, dim)) + bit8_weight, scales, _ = dynamically_quantize_per_channel( + weight[expert_idx].float(), -128, 127, self.target_dtype + ) + bit8_weight_list.append( + bit8_weight.reshape(1, intermediate_size, dim) + ) scales_list.append(scales.reshape(1, intermediate_size)) - cur_state_dict[f"{fqn}.{weight_name}"] = torch.cat(bit8_weight_list, dim=0) - cur_state_dict[f"{fqn}.{scales_name}"] = torch.cat(scales_list, dim=0) + cur_state_dict[f"{fqn}.{weight_name}"] = torch.cat( + bit8_weight_list, dim=0 + ) + cur_state_dict[f"{fqn}.{scales_name}"] = torch.cat( + scales_list, dim=0 + ) return cur_state_dict @@ -99,19 +124,28 @@ def convert_for_runtime(self): class WeightOnlyBit8Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None, target_dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + target_dtype=None, + ) -> None: assert target_dtype is not None - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features self.out_features = out_features - self.register_buffer("weight", torch.empty((out_features, in_features), dtype=target_dtype)) + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=target_dtype) + ) self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -124,69 +158,106 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype): self.target_dtype = target_dtype - self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) - self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype)) - self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)) - - self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) - self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16)) - self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)) + self.register_buffer( + "w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype) + ) + self.register_buffer( + "w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype) + ) + self.register_buffer( + "w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype) + ) + + self.register_buffer( + "scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16) + ) + self.register_buffer( + "scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16) + ) + self.register_buffer( + "scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16) + ) def forward(self, x, expert_indices): - w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D] - w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D] + w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D] + w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D] w2_weights = self.w2.to(x.dtype)[expert_indices] - x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype)) - x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype) - expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D] + x1 = F.silu( + torch.einsum("ti,taoi -> tao", x, w1_weights) + * self.scales1[expert_indices].to(x.dtype) + ) + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) * self.scales3[ + expert_indices + ].to(x.dtype) + expert_outs = torch.einsum( + "tao, taio -> tai", (x1 * x3), w2_weights + ) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D] return expert_outs def quantize( checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"), - mode: str = 'int8', - label: str = '', + mode: str = "int8", + label: str = "", ) -> None: assert checkpoint_path.is_file(), checkpoint_path - device = 'cpu' + device = "cpu" precision = torch.bfloat16 print("Loading model ...") t0 = time.time() - with torch.device('meta'): + with torch.device("meta"): model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) model = model.to(dtype=precision, device=device) - if mode == 'int8': - print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) quant_handler = WeightOnlyBit8QuantHandler(model, target_dtype=torch.int8) quantized_state_dict = quant_handler.create_quantized_state_dict() dir_name = checkpoint_path.parent base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f'{label}int8.pth') + new_base_name = base_name.replace(".pth", f"{label}int8.pth") else: raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8,]") quantize_path = dir_name / new_base_name print(f"Writing quantized weights to {quantize_path}") - quantize_path.unlink(missing_ok=True) # remove existing file if one already there + quantize_path.unlink(missing_ok=True) # remove existing file if one already there torch.save(quantized_state_dict, quantize_path) print(f"Quantization complete took {time.time() - t0:.02f} seconds") return -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Quantize a model.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') - parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') - parser.add_argument('--label', type=str, default='_', help='label to add to output filename') + + parser = argparse.ArgumentParser(description="Quantize a model.") + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Path to the model checkpoint to be quantized.", + ) + parser.add_argument( + "--mode", + "-q", + type=str, + default="int8", + choices=["int8", "int4", "int4-gptq"], + help="type of quantization to perform", + ) + parser.add_argument( + "--label", type=str, default="_", help="label to add to output filename" + ) args = parser.parse_args() quantize(args.checkpoint_path, args.mode, args.label) diff --git a/mixtral-moe/scripts/convert_hf_checkpoint.py b/mixtral-moe/scripts/convert_hf_checkpoint.py index e659931..1df4dbd 100644 --- a/mixtral-moe/scripts/convert_hf_checkpoint.py +++ b/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -51,13 +51,15 @@ def convert_hf_checkpoint( merged_result = {} for file in sorted(pt_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + state_dict = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) merged_result.update(state_dict) final_result = {} for key, value in merged_result.items(): if "layers" in key: - abstract_key = re.sub(r'.(\d+).', '.{}.', key) - layer_num = re.search(r'\d+', key).group(0) + abstract_key = re.sub(r".(\d+).", ".{}.", key) + layer_num = re.search(r"\d+", key).group(0) new_key = weight_map[abstract_key] if new_key is None: continue @@ -77,9 +79,18 @@ def convert_hf_checkpoint( del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] elif "w1" in key or "w3" in key: - final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .contiguous() + ) elif "w2" in key: - final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous() + final_result[key] = ( + final_result[key] + .reshape(config.num_experts, config.intermediate_size, config.dim) + .permute(0, 2, 1) + .contiguous() + ) elif "gate" in key: final_result[key] = final_result[key].contiguous() @@ -87,11 +98,16 @@ def convert_hf_checkpoint( torch.save(final_result, checkpoint_dir / "model.pth") -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') - parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) - parser.add_argument('--model_name', type=str, default=None) + + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") + parser.add_argument( + "--checkpoint_dir", + type=Path, + default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), + ) + parser.add_argument("--model_name", type=str, default=None) args = parser.parse_args() convert_hf_checkpoint( diff --git a/mixtral-moe/scripts/download.py b/mixtral-moe/scripts/download.py index d1505ef..9f0c934 100644 --- a/mixtral-moe/scripts/download.py +++ b/mixtral-moe/scripts/download.py @@ -11,20 +11,38 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) try: - snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors") + snapshot_download( + repo_id, + local_dir=f"checkpoints/{repo_id}", + local_dir_use_symlinks=False, + token=hf_token, + ignore_patterns="*.safetensors", + ) except HTTPError as e: if e.response.status_code == 401: - print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) else: raise e -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') - parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') - parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + + parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") + parser.add_argument( + "--repo_id", + type=str, + default="checkpoints/meta-llama/llama-2-7b-chat-hf", + help="Repository ID to download from.", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token." + ) args = parser.parse_args() hf_download(args.repo_id, args.hf_token) diff --git a/mixtral-moe/tp.py b/mixtral-moe/tp.py index 75336b5..289515d 100644 --- a/mixtral-moe/tp.py +++ b/mixtral-moe/tp.py @@ -17,17 +17,21 @@ def _get_rank() -> int: return int(os.environ.get("LOCAL_RANK", "0")) + def is_local(): return _get_rank() == 0 + def local_break(): if is_local(): breakpoint() dist.barrier() + def _get_world_size() -> int: return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) + def maybe_init_dist() -> Optional[int]: try: # provided by torchrun @@ -45,23 +49,25 @@ def maybe_init_dist() -> Optional[int]: dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) return rank + rank = _get_rank() world_size = _get_world_size() + def shard(x, dim): assert x.size(dim=dim) % world_size == 0 return torch.tensor_split(x, world_size, dim=dim)[rank] -def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: + +def _apply_tp_linear( + linear: nn.Linear, style: str, weight_splits: List[int] = [] +) -> None: rank = _get_rank() world_size = _get_world_size() # Linear's weight matrix is transposed, and is of shape # (linear.out_features, linear.in_features) - dim_lookup = { - "colwise": (0, "out_features"), - "rowwise": (1, "in_features") - } + dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")} assert style in dim_lookup shard_dim, size_attr = dim_lookup[style] @@ -73,7 +79,7 @@ def shard_qkv(qkv, dim, weight_splits): q = shard(q, dim) k = shard(k, dim) v = shard(v, dim) - return torch.cat((q,k,v), dim=dim) + return torch.cat((q, k, v), dim=dim) # shard if weight_splits: @@ -102,13 +108,20 @@ def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None: mlp.cond_ffn.w2 = nn.Parameter(shard(mlp.cond_ffn.w2, 2), requires_grad=False) if hasattr(mlp.cond_ffn, "scales1"): - mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False) - mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False) + mlp.cond_ffn.scales1 = nn.Parameter( + shard(mlp.cond_ffn.scales1, 1), requires_grad=False + ) + mlp.cond_ffn.scales3 = nn.Parameter( + shard(mlp.cond_ffn.scales3, 1), requires_grad=False + ) mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False) world_size = _get_world_size() - mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( - output, "sum", list(range(world_size)))) + mlp.cond_ffn.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output, "sum", list(range(world_size)) + ) + ) def _apply_tp_attn(attn: Attention) -> None: @@ -127,8 +140,11 @@ def _apply_tp_attn(attn: Attention) -> None: attn.head_dim = attn.dim // attn.n_head attn.n_local_heads = attn.n_local_heads // world_size - attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( - output[0], "sum", list(range(world_size)))) + attn.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output[0], "sum", list(range(world_size)) + ) + ) def _apply_tp_Transformer(Transformer: Transformer) -> None: diff --git a/model.py b/model.py index 0660bc2..4e5469d 100644 --- a/model.py +++ b/model.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from dataclasses import dataclass +from collections import defaultdict from typing import Optional import torch @@ -11,12 +12,16 @@ from torch import Tensor from torch.nn import functional as F +from attention_utils import scaled_dot_product_attention +from cache import get_cache_constructor + def find_multiple(n: int, k: int) -> int: if n % k == 0: return n return n + k - (n % k) + @dataclass class ModelArgs: block_size: int = 2048 @@ -29,6 +34,8 @@ class ModelArgs: head_dim: int = 64 rope_base: float = 10000 norm_eps: float = 1e-5 + attention_bias: bool = False + max_length: int = 4096 def __post_init__(self): if self.n_local_heads == -1: @@ -44,47 +51,112 @@ def from_name(cls, name: str): if name in transformer_configs: return cls(**transformer_configs[name]) # fuzzy search - config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + config = [ + config + for config in transformer_configs + if config in str(name).upper() or config in str(name) + ] # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, # take longer name (as it have more symbols matched) if len(config) > 1: config.sort(key=len, reverse=True) - assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + assert len(config[0]) != len( + config[1] + ), name # make sure only one 'best' match return cls(**transformer_configs[config[0]]) transformer_configs = { - "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), + "CodeLlama-7b-Python-hf": dict( + block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000 + ), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), "30B": dict(n_layer=60, n_head=52, dim=6656), - "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf - "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), - "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "34B": dict( + n_layer=48, + n_head=64, + dim=8192, + vocab_size=32000, + n_local_heads=8, + intermediate_size=22016, + rope_base=1000000, + ), # CodeLlama-34B-Python-hf + "70B": dict( + n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672 + ), + "Mistral-7B": dict( + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + vocab_size=32000, + ), "stories15M": dict(n_layer=6, n_head=6, dim=288), "stories110M": dict(n_layer=12, n_head=12, dim=768), - "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256), + "Meta-Llama-3-8B": dict( + block_size=8192, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + vocab_size=128256, + rope_base=500000, + max_length=8192, + ), + "Meta-Llama-3-8B-Instruct-4-Layers": dict( + block_size=8192, + n_layer=4, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + vocab_size=128256, + rope_base=500000, + max_length=8192, + ), + "Qwen2-1.5B-Instruct": dict( + block_size=32768, + n_layer=28, + n_head=12, + n_local_heads=2, + dim=1536, + intermediate_size=8960, + vocab_size=151936, + rope_base=1000000, + attention_bias=True, + norm_eps=1e-6, + max_length=32768, + ), + "Qwen2-0.5B-Instruct": dict( + block_size=32768, + n_layer=24, + n_head=14, + n_local_heads=2, + dim=896, + intermediate_size=4864, + vocab_size=151936, + rope_base=1000000, + attention_bias=True, + norm_eps=1e-6, + max_length=32768, + ), + "Meta-Llama-3-8B-gist-finetune": dict( + block_size=8192, + n_layer=32, + n_head=32, + n_local_heads=8, + dim=4096, + intermediate_size=14336, + vocab_size=128257, + rope_base=500000 + ), } -class KVCache(nn.Module): - def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): - super().__init__() - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) - self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) - - def update(self, input_pos, k_val, v_val): - # input_pos: [S], k_val: [B, H, S, D] - assert input_pos.shape[0] == k_val.shape[2] - - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - - return k_out, v_out class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: @@ -92,42 +164,89 @@ def __init__(self, config: ModelArgs) -> None: self.config = config self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) - self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) self.output = nn.Linear(config.dim, config.vocab_size, bias=False) self.freqs_cis: Optional[Tensor] = None - self.mask_cache: Optional[Tensor] = None - self.max_batch_size = -1 - self.max_seq_length = -1 - def setup_caches(self, max_batch_size, max_seq_length): - if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: - return + # Fixed for now + self.max_batch_size = 1 + + def setup_caches(self, **kwargs): + cache_strategy = kwargs.pop("cache_strategy") + head_dim = self.config.dim // self.config.n_head - max_seq_length = find_multiple(max_seq_length, 8) - self.max_seq_length = max_seq_length - self.max_batch_size = max_batch_size + dtype = self.output.weight.dtype # For quantized layers, dtype is encoded in scales if hasattr(self.output, "scales"): dtype = self.output.scales.dtype elif hasattr(self.output, "scales_and_zeros"): dtype = self.output.scales_and_zeros.dtype - for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype) - - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype) - self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) - - def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + for layer_idx, b in enumerate(self.layers): + cache_constructor, relevant_kwargs = get_cache_constructor( + cache_strategy=cache_strategy + ) + # Only pass in the kwargs we need for the cache we chose (useful especially for debugging) + layerwise_keys = {"max_cache_length", "drop_amount", "recent_window"} + layer_kwargs = { + k: kwargs[k][layer_idx] if k in layerwise_keys else kwargs[k] + for k in relevant_kwargs + } + b.attention.kv_cache = cache_constructor( + self.max_batch_size, + self.config.n_local_heads, + head_dim, + dtype, + **layer_kwargs, + ) + + self.freqs_cis = precompute_freqs_cis( + self.config.block_size, + self.config.dim // self.config.n_head, + self.config.rope_base, + dtype, + ) + + def reset_caches(self): + for layer in self.layers: + layer.attention.kv_cache.reset() + + def get_cache_stats(self, prompt_len, gen_len): + stats = {} + final_seq_len = prompt_len + gen_len + avgs = defaultdict(list) + for layer_idx, layer in enumerate(self.layers): + stat = layer.attention.kv_cache.compute_statistics( + seq_len=torch.tensor(final_seq_len) + ) + for k, v in stat.items(): + stats[f"{k}_{layer_idx}"] = v + avgs[k].append(v) + + for k, v in avgs.items(): + stats[f"{k}_avg"] = sum(v) / len(v) + + return stats + + def min_cache_length(self): + return min([layer.attention.kv_cache.max_cache_length for layer in self.layers]) + + def forward( + self, + idx: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + ) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx) for i, layer in enumerate(self.layers): - x = layer(x, input_pos, freqs_cis, mask) + x = layer(x, idx, input_pos, freqs_cis, mask) x = self.norm(x) logits = self.output(x) return logits @@ -145,8 +264,17 @@ def __init__(self, config: ModelArgs) -> None: self.ffn_norm = RMSNorm(config.dim, config.norm_eps) self.attention_norm = RMSNorm(config.dim, config.norm_eps) - def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + def forward( + self, + x: Tensor, + input_ids: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + ) -> Tensor: + h = x + self.attention( + self.attention_norm(x), input_ids, freqs_cis, mask, input_pos + ) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -158,7 +286,7 @@ def __init__(self, config: ModelArgs): total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim # key, query, value projections for all heads, but in a batch - self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias) self.wo = nn.Linear(config.dim, config.dim, bias=False) self.kv_cache = None @@ -175,7 +303,14 @@ def load_hook(self, state_dict, prefix, *args): wv = state_dict.pop(prefix + "wv.weight") state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) - def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward( + self, + x: Tensor, + input_ids: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim @@ -190,12 +325,33 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) - if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) - - k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + k, v, kv_mask, attn_callback = self.kv_cache.update( + input_pos, k, v, input_ids=input_ids + ) + + k_rep = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v_rep = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + if kv_mask is not None: + kv_mask = kv_mask.repeat_interleave( + self.n_head // self.n_local_heads, dim=1 + ) + + y, attn = scaled_dot_product_attention( + q, + k_rep, + v_rep, + attn_mask=mask if mask is not None else kv_mask, + dropout_p=0.0, + return_attn=attn_callback and attn_callback["func"] is not None, + **{} if attn_callback is None else attn_callback.get("kwargs", {}), + ) + + if attn_callback: + # Mean pool over the grouped queries (average over self.n_head // self.n_local_heads) + attn = attn.view( + bsz, self.n_local_heads, self.n_head // self.n_local_heads, seqlen, -1 + ).mean(dim=2) + attn_callback["func"](input_pos, input_ids, k, v, attn) y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) @@ -229,10 +385,11 @@ def forward(self, x: Tensor) -> Tensor: def precompute_freqs_cis( - seq_len: int, n_elem: int, base: int = 10000, - dtype: torch.dtype = torch.bfloat16 + seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16 ) -> Tensor: - freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) t = torch.arange(seq_len, device=freqs.device) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) diff --git a/prompt_compression.py b/prompt_compression.py new file mode 100644 index 0000000..d0b4cc1 --- /dev/null +++ b/prompt_compression.py @@ -0,0 +1,191 @@ +import torch +from abc import ABC, abstractmethod + + +class PromptCompressor(ABC): + def __init__(self, head_specific, **kwargs) -> None: + # Assign each kwarg as an attribute of the class + for key, value in kwargs.items(): + setattr(self, key, value) + + self.head_specific = head_specific + assert self.is_compatible(), f"Prompt compressor ({self.__class__.__name__}) is not compatible with the chosen cache strategy." + + @abstractmethod + def requires_attn(self) -> bool: + pass + + @abstractmethod + def __call__(self, *args: torch.Any, **kwds: torch.Any) -> torch.Any: + pass + + @abstractmethod + def is_compatible(self) -> bool: + pass + + +class PromptCompressorRandom(PromptCompressor): + def __init__(self, head_specific, **kwargs) -> None: + super().__init__(head_specific, **kwargs) + + def is_compatible(self) -> bool: + # Can be used with any cache + return True + + def requires_attn(self) -> bool: + return False + + def __call__(self, input_pos, k_val, v_val): + seq_len = input_pos.shape[0] + global_idxs = torch.arange(self.global_tokens, device=input_pos.device) + rand_idxs = ( + ( + self.global_tokens + + torch.randperm(seq_len - self.global_tokens, device=input_pos.device)[ + : self.max_cache_length - self.global_tokens + ] + ) + .sort() + .values + ) + keep_idxs = torch.cat([global_idxs, rand_idxs], dim=0) + assert len(keep_idxs) == self.max_cache_length + k_val = k_val[:, :, keep_idxs] + v_val = v_val[:, :, keep_idxs] + return keep_idxs, k_val, v_val + + +class PromptCompressorRecentGlobal(PromptCompressor): + def __init__(self, head_specific, **kwargs) -> None: + super().__init__(head_specific, **kwargs) + + def is_compatible(self) -> bool: + # Can be used with any cache + return True + + def requires_attn(self) -> bool: + return False + + def __call__(self, input_pos, k_val, v_val): + # [global; ...; window - global] --> [global; window - global] + # Indices for first global_tokens tokens and last (window - global_tokens) tokens + # Making this a tensor seems to give a speedup, but I haven't fully benchmarked + keep_idxs = torch.tensor( + list(range(self.global_tokens)) + + list( + range( + input_pos.shape[0] - self.max_cache_length + self.global_tokens, + input_pos.shape[0], + ) + ), + dtype=torch.long, + device=k_val.device, + ) + assert len(keep_idxs) == self.max_cache_length + k_val = k_val[:, :, keep_idxs] + v_val = v_val[:, :, keep_idxs] + return keep_idxs, k_val, v_val + + +class PromptCompressorSnapKV(PromptCompressor): + """ + Use SnapKV to compress the prompt + Inspired by the pseudo code on Page 7 of https://arxiv.org/abs/2404.14469 + """ + + def __init__(self, head_specific, **kwargs) -> None: + super().__init__(head_specific, **kwargs) + + self.kernel_size = 5 + self.observation_len = 16 + + self.pool = torch.nn.AvgPool1d( + self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + ceil_mode=False, + count_include_pad=False, + ) + + def is_compatible(self) -> bool: + # Can only be used with head-specific KV-caches + return self.head_specific + + def requires_attn(self) -> bool: + return True + + def __call__(self, input_pos, k_val, v_val, attn): + seq_len = input_pos.shape[0] + obs_len = min(self.observation_len, seq_len) + + priority = attn[:, :, -obs_len:, :].mean(dim=2) + prev_shape = priority.shape + + # We'll be returning the attention history so we need to keep a copy before it's modified + attn_history = priority.clone() + priority = self.pool(priority) + assert ( + priority.shape == prev_shape + ), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}" + priority[:, :, -obs_len:] = 1.0 # Ensure the observation window is selected + priority[:, :, : self.global_tokens] = ( + 1.0 # Ensure the global tokens are selected + ) + keep_idxs = ( + priority.topk(self.max_cache_length, dim=-1).indices.sort(dim=-1).values + ) + + attn_history = attn_history.gather(2, keep_idxs) + + keep_idxs_rep = keep_idxs.unsqueeze(-1).expand(-1, -1, -1, k_val.shape[-1]) + k_val_compressed = k_val.gather(2, keep_idxs_rep) + v_val_compressed = v_val.gather(2, keep_idxs_rep) + + return keep_idxs.squeeze(0), k_val_compressed, v_val_compressed, attn_history + + +class PromptCompressorL2(PromptCompressor): + def __init__(self, head_specific, **kwargs) -> None: + super().__init__(head_specific, **kwargs) + + def is_compatible(self) -> bool: + return self.head_specific + + def requires_attn(self) -> bool: + return False + + def __call__(self, input_pos, k_val, v_val): + key_norm = torch.linalg.vector_norm(k_val, ord=2, dim=-1) + + # Give low score to global and recent tokens + locality_mask = torch.logical_or( + input_pos < self.global_tokens, + input_pos >= input_pos.shape[0] - self.recent_window, + ).view(1, 1, -1) + + eviction_scores = key_norm.masked_fill(locality_mask, float("-inf")) + + keep_idxs = ( + eviction_scores.topk(self.max_cache_length, dim=-1, largest=False) + .indices.sort(dim=-1) + .values + ) + + keep_idxs_rep = keep_idxs.unsqueeze(-1).expand(-1, -1, -1, k_val.shape[-1]) + k_val_compressed = k_val.gather(2, keep_idxs_rep) + v_val_compressed = v_val.gather(2, keep_idxs_rep) + + return keep_idxs, k_val_compressed, v_val_compressed + + +def prompt_compressor_constructor(strategy): + if strategy == "recent_global": + return PromptCompressorRecentGlobal + elif strategy == "snapkv": + return PromptCompressorSnapKV + elif strategy == "l2": + return PromptCompressorL2 + elif strategy == "random": + return PromptCompressorRandom + else: + raise ValueError(f"Unknown prompt compression strategy: {strategy}") diff --git a/prompts/long_prompt_long_output.json b/prompts/long_prompt_long_output.json new file mode 100644 index 0000000..5a8ad2a --- /dev/null +++ b/prompts/long_prompt_long_output.json @@ -0,0 +1,4 @@ +{ + "instruction": "You are an architect tasked with drawing up plans for a modern residential house.\n\nArchitectural Plan Creation Instructions\n\nObjective:\nCreate a comprehensive set of architectural plans for a modern residential house. The plans should include detailed layouts, elevations, sections, and necessary annotations to guide the construction process. The design should focus on functionality, aesthetics, sustainability, and compliance with local building codes.\n\nRequirements:\n\nGeneral Layout:\n\nTotal area: Approximately 2,500 square feet.\nNumber of floors: Two.\nNumber of bedrooms: Four (including a master suite).\nNumber of bathrooms: Three full bathrooms and one half bathroom.\nCommon areas: Open-plan kitchen, dining area, living room, and a study/office.\nAdditional spaces: Laundry room, garage (for two cars), storage rooms, and a small basement.\nSite Plan:\n\nInclude property boundaries, adjacent streets, and any existing structures.\nShow the placement of the house, driveway, pathways, garden, and outdoor living spaces (e.g., patio, deck).\nInclude landscaping elements like trees, shrubs, and lawn areas.\nFloor Plans:\n\nGround Floor: Include entryway, living spaces, kitchen, one bedroom (guest room), one full bathroom, and access to the garage.\nSecond Floor: Include master suite with attached bathroom and walk-in closet, two additional bedrooms, one full bathroom, and a study/office.\nIndicate all door and window placements, furniture layouts, and circulation paths.\nElevations:\n\nProvide front, rear, and side elevations.\nShow the external appearance, including the roof design, facade materials, window and door placements, and any architectural features (e.g., balconies, porches).\nSections:\n\nInclude at least two sections (one longitudinal and one cross-sectional) showing internal details.\nHighlight the relationship between different floors and ceiling heights.\nShow structural elements like beams, columns, and floor slabs.\nRoof Plan:\n\nIndicate the roof slope, materials, drainage system, and any roof features (e.g., skylights, chimneys).\nElectrical and Plumbing Plans:\n\nShow the layout of electrical outlets, switches, lighting fixtures, and major appliances.\nInclude the plumbing layout for water supply and drainage, showing the location of pipes, fixtures, and connections.\nMaterials and Finishes:\n\nSpecify the materials for walls, floors, ceilings, and roofs.\nInclude details on interior and exterior finishes (e.g., paint, tiles, cladding).\nSustainability Features:\n\nIncorporate energy-efficient systems (e.g., HVAC, solar panels).\nUse sustainable building materials.\nPlan for natural lighting and ventilation.\nInclude rainwater harvesting and greywater recycling systems if possible.\nCompliance:\n\nEnsure the design complies with local building codes and regulations.\nInclude necessary annotations and notes for construction guidelines.\n\n", + "input": "You must return the following:\n- Include a detailed list of materials and specifications.\n- Add a cover sheet with project title, address, date, and designer's name.\n- Add a sheet for each component with detailed plans.\n- Ensure all documents are clearly labeled and organized." +} \ No newline at end of file diff --git a/prompts/long_prompt_short_output.json b/prompts/long_prompt_short_output.json new file mode 100644 index 0000000..ad5b46f --- /dev/null +++ b/prompts/long_prompt_short_output.json @@ -0,0 +1,4 @@ +{ + "instruction": "Carefully read the beginning of the Wikipedia page on the Guggenheim museum. You will be asked to answer a question at the end.\n\n ### Introduction The Solomon R. Guggenheim Museum, often referred to as The Guggenheim, is an art museum at 1071 Fifth Avenue between 88th and 89th Streets on the Upper East Side of Manhattan in New York City. It hosts a permanent collection of Impressionist, Post-Impressionist, early Modern, and contemporary art and also features special exhibitions throughout the year. It was established by the Solomon R. Guggenheim Foundation in 1939 as the Museum of Non-Objective Painting, under the guidance of its first director, Hilla von Rebay. The museum adopted its current name in 1952, three years after the death of its founder Solomon R. Guggenheim. It continues to be operated and owned by the Solomon R. Guggenheim Foundation. The museum's building, a landmark work of 20th-century architecture designed by Frank Lloyd Wright, drew controversy for the unusual shape of its display spaces and took 15 years to design and build; it was completed in 1959. It consists of a six-story, bowl-shaped main gallery to the south, a four-story \"monitor\" to the north, and a ten-story annex to the northeast. A six-story helical ramp extends along the main gallery's perimeter, under a central ceiling skylight. The Thannhauser Collection is housed within the top three stories of the monitor, and there are additional galleries in the annex and a learning center in the basement. The museum building's design was controversial when it was completed but was widely praised afterward. The building underwent extensive renovations from 1990 to 1992, when the annex was built, and it was renovated again from 2005 to 2008. The museum's collection has grown over the decades and is founded upon several important private collections, including those of Guggenheim, Karl Nierendorf, Katherine Sophie Dreier, Justin Thannhauser, Rebay, Giuseppe Panza, Robert Mapplethorpe, and the Bohen Foundation. The collection, which includes around 8,000 works as of 2022, is shared with sister museums in Bilbao, Spain, and Venice, Italy. In 2023, nearly 861,000 people visited the museum. # History ## Early years and Hilla Rebay Solomon R. Guggenheim, a member of a wealthy mining family, began collecting works of the old masters in the 1890s. In 1926, he met artist Hilla von Rebay, who introduced him to European avant-garde art, in particular abstract art that she felt had a spiritual and utopian aspect (non-objective art). Guggenheim completely changed his collecting strategy, turning to the work of Wassily Kandinsky, among others. He began to display his collection to the public at his apartment in the Plaza Hotel in New York City. Guggenheim and Rebay initially considered building a museum at Rockefeller Center in Manhattan. As the collection grew, Guggenheim established the Solomon R. Guggenheim Foundation, in 1937, to foster the appreciation of modern art. The foundation's first venue, the Museum of Non-Objective Painting, opened in 1939, under Rebay's direction, at 24 East 54th Street in midtown Manhattan. Under her guidance, Guggenheim sought to include in the collection the most important examples of non-objective art by early modernists. He wanted to display the collection at the 1939 New York World's Fair in Queens, but Rebay advocated for a more permanent location in Manhattan. By the early 1940s, the foundation had accumulated such a large collection of avant-garde paintings that the need for a permanent museum was apparent, and Rebay wanted to establish it before Guggenheim died. ## Design process In 1943, Rebay and Guggenheim wrote a letter to Frank Lloyd Wright asking him to design a structure to house and display the collection. Rebay thought the 76-year-old Wright was dead, but Guggenheim's wife Irene Rothschild Guggenheim knew better and suggested that Rebay contact him. Wright accepted the opportunity to experiment with his \"organic\" style in an urban setting, saying that he had never seen a museum that was \"properly designed\". He was hired to design the building in June 1943. He was to receive a 10 percent commission on the project, which was expected to cost at least $1 million. It took him 15 years, more than 700 sketches, and six sets of working drawings to create and complete the museum, after a series of difficulties and delays; the cost eventually doubled from the initial estimate. Rebay envisioned a space that would facilitate a new way of seeing modern art. She wrote Wright that \"each of these great masterpieces should be organized into space, and only you ... would test the possibilities to do so. ... I want a temple of spirit, a monument!\" Critic Paul Goldberger later wrote that Wright's modernist building was a catalyst for change, making it \"socially and culturally acceptable for an architect to design a highly expressive, intensely personal museum. In this sense almost every museum of our time is a child of the Guggenheim.\" The Guggenheim is the only museum Wright designed; its urban location required him to design it in a vertical rather than horizontal form, far different from his earlier, rural works. Since he was not licensed as an architect in New York, he relied on Arthur Cort Holden, of the architectural firm Holden, McLaughlin & Associates, to deal with New York City's Board of Standards and Appeals. From 1943 to early 1944, Wright produced four differing designs. One had a hexagonal shape and level floors for the galleries, though all the others had circular schemes and used a ramp continuing around the building. In his notes, he indicated that he wanted a \"well proportioned floor space from bottom to top—a wheel chair going around and up and down\". His original concept was called an inverted \"ziggurat\", because it resembled the steep steps on the ziggurats built in ancient Mesopotamia. Several architecture professors have speculated that the helical ramp and glass dome of Giuseppe Momo's 1932 staircase at the Vatican Museums was an inspiration for Wright's ramp and atrium.", + "input": "Question: Which is the largest number? A) Frank Lloyd Wright's age in 1943. B) The size of the collection at the Guggenheim. C) The building number of the museum's first venue. D) The number of sketches it took Frank Lloyd Wright to create the museum." +} \ No newline at end of file diff --git a/prompts/noise_qa.json b/prompts/noise_qa.json new file mode 100644 index 0000000..3b7066f --- /dev/null +++ b/prompts/noise_qa.json @@ -0,0 +1,4 @@ +{ + "instruction": "Answer the Question below. Ignore the \"*\".", + "input": "**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\nWhat is (10 * 10) - 5? Explain how you arrived at the answer as if you were helping a 5 year old just starting to learn math.\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n***********\n**********\n**********\n**********\n**********\n**********" +} \ No newline at end of file diff --git a/prompts/reverse_list.json b/prompts/reverse_list.json new file mode 100644 index 0000000..d43abf7 --- /dev/null +++ b/prompts/reverse_list.json @@ -0,0 +1,4 @@ +{ + "instruction": "Write this list in reverse order:", + "input": "a dog named Remy, a homemade hat, a wicked witch, a head of broccoli, a really smelly fried egg, a famous baseball player named Aaron Judge, and a mysterious aunt." +} \ No newline at end of file diff --git a/prompts/short_prompt_long_output.json b/prompts/short_prompt_long_output.json new file mode 100644 index 0000000..29fd6ac --- /dev/null +++ b/prompts/short_prompt_long_output.json @@ -0,0 +1,4 @@ +{ + "instruction": "Write a detailed textbook on how to build a house from scratch.", + "input": "Write a separate chapter for each stage from initial planning to furnishing." +} \ No newline at end of file diff --git a/prompts/short_prompt_short_output.json b/prompts/short_prompt_short_output.json new file mode 100644 index 0000000..6e70249 --- /dev/null +++ b/prompts/short_prompt_short_output.json @@ -0,0 +1,3 @@ +{ + "instruction": "Which architect designed the Guggenheim?" +} \ No newline at end of file diff --git a/quantize.py b/quantize.py index fb56642..cbf6784 100644 --- a/quantize.py +++ b/quantize.py @@ -21,6 +21,7 @@ ##### Quantization Primitives ###### + def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # assumes symmetric quantization # assumes axis == 0 @@ -55,6 +56,7 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): return quant, scales, zero_points + def get_group_qparams(w, n_bit=4, groupsize=128): # needed for GPTQ with padding if groupsize > w.shape[-1]: @@ -161,6 +163,7 @@ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): w_int32, scales, zeros, n_bit, groupsize ) + class QuantHandler: def __init__(self, mod): self.mod = mod @@ -171,6 +174,7 @@ def create_quantized_state_dict(self) -> "StateDict": def convert_for_runtime(self) -> "nn.Module": pass + class GPTQQuantHandler(QuantHandler): """ This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. @@ -233,6 +237,7 @@ class GPTQQuantHandler(QuantHandler): names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the corresponding quantized weights and qparams. """ + def __init__(self): assert self.mod is not None assert self.get_qparams_func is not None @@ -242,7 +247,14 @@ def __init__(self): assert self.make_names_and_values_dict_func is not None @staticmethod - def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + def get_inputs( + model, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "MultiInput": input_recorder = InputRecorder( model, tokenizer, @@ -264,9 +276,9 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati ) inputs = input_recorder.get_recorded_inputs() assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, "+ - f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ - f"{calibration_seq_length})" + f"No inputs were collected, use a task other than {calibration_tasks}, " + + "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " + + f"{calibration_seq_length})" ) print(f"Obtained {len(inputs[0].values)} calibration samples") return inputs @@ -283,7 +295,14 @@ def create_quantized_state_dict( calibration_seq_length, pad_calibration_inputs, ) -> "StateDict": - inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + inputs = GPTQQuantHandler.get_inputs( + self.mod, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) print("Tracing model for GPTQ") GPTQ_runner = GenericGPTQRunner( self.mod, @@ -297,7 +316,7 @@ def create_quantized_state_dict( self.dequantize_func, self.combine_qparams_list_func, self.make_names_and_values_dict_func, - self.skip_layer_func + self.skip_layer_func, ) print("Applying GPTQ to weights") @@ -307,15 +326,22 @@ def create_quantized_state_dict( def convert_for_runtime(self) -> "nn.Module": pass + ##### Weight-only int8 per-channel quantized code ###### + def replace_linear_weight_only_int8_per_channel(module): for name, child in module.named_children(): if isinstance(child, nn.Linear): - setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) else: replace_linear_weight_only_int8_per_channel(child) + class WeightOnlyInt8QuantHandler: def __init__(self, mod): self.mod = mod @@ -325,7 +351,9 @@ def create_quantized_state_dict(self): cur_state_dict = self.mod.state_dict() for fqn, mod in self.mod.named_modules(): if isinstance(mod, torch.nn.Linear): - int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) cur_state_dict[f"{fqn}.weight"] = int8_weight cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) @@ -337,58 +365,89 @@ def convert_for_runtime(self): class WeightOnlyInt8Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features self.out_features = out_features - self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + ##### weight only int4 per channel groupwise quantized code ###### + def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): weight_int32, scales_and_zeros = group_quantize_tensor( weight_bf16, n_bit=4, groupsize=groupsize ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) return weight_int4pack, scales_and_zeros def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + c = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c -def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + def replace_linear_int4(module, groupsize, inner_k_tiles, padding): for name, child in module.named_children(): if isinstance(child, nn.Linear): if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, - )) + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + ), + ) elif padding: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, - )) + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + ), + ) else: replace_linear_int4(child, groupsize, inner_k_tiles, padding) @@ -403,11 +462,11 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): assert inner_k_tiles in [2, 4, 8] @torch.no_grad() - def create_quantized_state_dict(self, use_cuda = True): + def create_quantized_state_dict(self, use_cuda=True): if use_cuda: - device="cuda" + device = "cuda" else: - device="cpu" + device = "cpu" cur_state_dict = self.mod.state_dict() for fqn, mod in self.mod.named_modules(): @@ -419,22 +478,35 @@ def create_quantized_state_dict(self, use_cuda = True): print(f"linear: {fqn}, in={in_features}, out={out_features}") weight = mod.weight.data - if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): if self.padding: from model import find_multiple import torch.nn.functional as F - print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) padded_in_features = find_multiple(in_features, 1024) - weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) else: - print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + - "and that groupsize and inner_k_tiles*16 evenly divide into it") + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) continue - weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( - weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles + weight_int4pack, scales_and_zeros = ( + prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to(device=device), + self.groupsize, + self.inner_k_tiles, + ) ) - cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") return cur_state_dict @@ -442,58 +514,78 @@ def convert_for_runtime(self): replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) return self.mod + class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): from model import find_multiple + self.mod = mod self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.padding = padding self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) - self.quantize_func = lambda w, qparams: \ - group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) - self.dequantize_func = lambda q, qparams: \ - group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() - self.combine_qparams_list_func = lambda qparams_list: \ - [torch.cat(x, dim=1) for x in zip(*qparams_list)] + self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], 4, groupsize + ) + self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams( + q, qparams[0], qparams[1], 4, groupsize + ).float() + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] # skip unless padding=True or its correctly sized self.skip_layer_func = lambda linear_weight: not ( - _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) + or padding ) + # we need to do the padding here, both for q and the qparams if necessary def make_names_and_values_dict_func(q, qparams): k = q.shape[1] new_k = find_multiple(k, 1024) # how much we need to pad the weight delta_k = new_k - q.shape[1] - final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales_and_zeros = pack_scales_and_zeros(*qparams) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] - final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + final_s_and_z = F.pad( + scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1 + ) return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func super().__init__() - def convert_for_runtime(self): replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) return self.mod + class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor def __init__( - self, in_features: int, out_features: int, - bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, ) -> None: super().__init__() self.padding = padding if padding: from model import find_multiple + self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -504,30 +596,42 @@ def __init__( self.inner_k_tiles = inner_k_tiles assert out_features % 8 == 0, "require out_features % 8 == 0" - assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + torch.empty( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ), ) self.register_buffer( "scales_and_zeros", - torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + torch.empty( + (in_features // groupsize, out_features, 2), dtype=torch.bfloat16 + ), ) def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(torch.bfloat16) if self.padding: import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) return linear_forward_int4( - input, - self.weight, self.scales_and_zeros, self.out_features, self.groupsize + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) def quantize( checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), - mode: str = 'int8', + mode: str = "int8", # following arguments only available when setting int4 quantization. groupsize: int = 128, # following arguments only used for GPTQ @@ -535,45 +639,51 @@ def quantize( calibration_limit: int = 1000, calibration_seq_length: int = 100, pad_calibration_inputs: bool = False, - percdamp: float = .01, + percdamp: float = 0.01, blocksize: int = 128, - label: str = '', + label: str = "", ) -> None: assert checkpoint_path.is_file(), checkpoint_path - device = 'cpu' + device = "cpu" precision = torch.bfloat16 print("Loading model ...") t0 = time.time() - with torch.device('meta'): + with torch.device("meta"): model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) model = model.to(dtype=precision, device=device) - if mode == 'int8': - print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) quant_handler = WeightOnlyInt8QuantHandler(model) quantized_state_dict = quant_handler.create_quantized_state_dict() dir_name = checkpoint_path.parent base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f'{label}int8.pth') + new_base_name = base_name.replace(".pth", f"{label}int8.pth") - elif mode == 'int4': - print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") + elif mode == "int4": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" + ) quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) quantized_state_dict = quant_handler.create_quantized_state_dict() dir_name = checkpoint_path.parent base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth") + new_base_name = base_name.replace(".pth", f"{label}int4.g{groupsize}.pth") - elif mode == 'int4-gptq': - print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") + elif mode == "int4-gptq": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ..." + ) quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) tokenizer_path = checkpoint_path.parent / "tokenizer.model" @@ -588,35 +698,89 @@ def quantize( calibration_tasks, calibration_limit, calibration_seq_length, - pad_calibration_inputs + pad_calibration_inputs, ) dir_name = checkpoint_path.parent base_name = checkpoint_path.name - new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth") + new_base_name = base_name.replace(".pth", f"{label}int4-gptq.g{groupsize}.pth") else: - raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]") + raise ValueError( + f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" + ) quantize_path = dir_name / new_base_name print(f"Writing quantized weights to {quantize_path}") - quantize_path.unlink(missing_ok=True) # remove existing file if one already there + quantize_path.unlink(missing_ok=True) # remove existing file if one already there torch.save(quantized_state_dict, quantize_path) print(f"Quantization complete took {time.time() - t0:.02f} seconds") return -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Quantize a model.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') - parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') - parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') - parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') - parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') - parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') - parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') - parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') - parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') - parser.add_argument('--label', type=str, default='_', help='label to add to output filename') + + parser = argparse.ArgumentParser(description="Quantize a model.") + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Path to the model checkpoint to be quantized.", + ) + parser.add_argument( + "--mode", + "-q", + type=str, + default="int8", + choices=["int8", "int4", "int4-gptq"], + help="type of quantization to perform", + ) + parser.add_argument( + "--groupsize", type=int, default=32, help="Group size for int4 quantization." + ) + parser.add_argument( + "--calibration_tasks", + type=str, + nargs="+", + default=["wikitext"], + help="tasks to do gptq calibration on, if doing gptq", + ) + parser.add_argument( + "--calibration_limit", + type=int, + default=1000, + help="number of samples to use for gptq calibration", + ) + parser.add_argument( + "--calibration_seq_length", + type=int, + default=100, + help="length of sequences to use for gptq calibration", + ) + parser.add_argument( + "--pad_calibration_inputs", + type=bool, + default=False, + help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower", + ) + parser.add_argument( + "--percdamp", type=float, default=0.01, help="gptq percentage dampening" + ) + parser.add_argument("--blocksize", type=int, default=128, help="blocksize for gptq") + parser.add_argument( + "--label", type=str, default="_", help="label to add to output filename" + ) args = parser.parse_args() - quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) + quantize( + args.checkpoint_path, + args.mode, + args.groupsize, + args.calibration_tasks, + args.calibration_limit, + args.calibration_seq_length, + args.pad_calibration_inputs, + args.percdamp, + args.blocksize, + args.label, + ) diff --git a/requirements.txt b/requirements.txt index 04f828c..deeba0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,15 @@ -torch +absl-py +accelerate +bert-score +blobfile +claudette +datasets +evaluate +huggingface_hub +nltk sentencepiece +rouge-score +ruff +scikit-learn tiktoken +git+https://github.com/google-research/bleurt.git diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 8a22106..be1df09 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import json +import os import re import shutil import sys @@ -11,6 +12,8 @@ from typing import Optional import torch +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -22,9 +25,16 @@ @torch.inference_mode() def convert_hf_checkpoint( *, - checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"), + checkpoint_dir: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" + ), model_name: Optional[str] = None, ) -> None: + out_model_path = checkpoint_dir / "model.pth" + if os.path.exists(out_model_path): + print(f"Model already exists at {out_model_path}") + return + if model_name is None: model_name = checkpoint_dir.name @@ -34,7 +44,7 @@ def convert_hf_checkpoint( # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not # currently supported. # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken - is_llama3 = "Llama-3" in model_name + is_llama3 = "Llama-3" in model_name and "gist" not in model_name if is_llama3: # Check if we have multiple original/consolidated.NN.pth files and report error # if we do for Llama 3. @@ -44,8 +54,8 @@ def convert_hf_checkpoint( if len(bin_files) > 1: raise ValueError( f"Multiple consolidated.NN.pth files found in {original_dir}. " - "Merging them into one model.pth file is not supported for Llama 3.") - + "Merging them into one model.pth file is not supported for Llama 3." + ) config = ModelArgs.from_name(model_name) print(f"Model config {config.__dict__}") @@ -53,20 +63,35 @@ def convert_hf_checkpoint( # Load the json file containing weight mapping if not is_llama3: model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" - - assert model_map_json.is_file() - - with open(model_map_json) as json_map: - bin_index = json.load(json_map) - + + if not model_map_json.is_file(): + model_map_json = checkpoint_dir / "model.safetensors.index.json" + + if model_map_json.is_file(): + # For larger models, the weights are stored in separate files, so we need to load the index. + with open(model_map_json) as json_map: + bin_index = json.load(json_map) + bin_files = { + checkpoint_dir / bin for bin in bin_index["weight_map"].values() + } + else: + # For smaller models, the weights are stored in a single file. + # Note it could be a bin file or a safetensors file. + if (checkpoint_dir / "pytorch_model.bin").exists(): + bin_files = {checkpoint_dir / "pytorch_model.bin"} + else: + bin_files = {checkpoint_dir / "model.safetensors"} weight_map = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - 'model.layers.{}.self_attn.rotary_emb.inv_freq': None, - 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', + "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", + "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", + "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", @@ -74,7 +99,6 @@ def convert_hf_checkpoint( "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", } - bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} else: # There is no separate pytorch_model.bin.index.json file for llama3. # Instead, we will just use all original/consolidated.NN.pth files. @@ -83,10 +107,8 @@ def convert_hf_checkpoint( original_dir = checkpoint_dir / "original" pattern = re.compile(r"^consolidated\.\d{2}\.pth$") bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)} - - def permute(w, n_head): - dim = config.dim + def permute(w, n_head, dim=config.dim): return ( w.view(n_head, 2, config.head_dim // 2, dim) .transpose(1, 2) @@ -95,14 +117,19 @@ def permute(w, n_head): merged_result = {} for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + if str(file).endswith(".safetensors"): + state_dict = load_file(str(file)) + else: + state_dict = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) merged_result.update(state_dict) final_result = {} if weight_map is not None: for key, value in merged_result.items(): if "layers" in key: - abstract_key = re.sub(r'(\d+)', '{}', key) - layer_num = re.search(r'\d+', key).group(0) + abstract_key = re.sub(r"(\d+)", "{}", key) + layer_num = re.search(r"\d+", key).group(0) new_key = weight_map[abstract_key] if new_key is None: continue @@ -117,31 +144,55 @@ def permute(w, n_head): q = final_result[key] k = final_result[key.replace("wq", "wk")] v = final_result[key.replace("wq", "wv")] - q = permute(q, config.n_head) - k = permute(k, config.n_local_heads) + if key.endswith("weight"): + q = permute(q, config.n_head) + k = permute(k, config.n_local_heads) + else: + # Permute bias to be compatible with the weight permutation + q = permute(q, config.n_head, dim=1).view(-1) + k = permute(k, config.n_local_heads, dim=1).view(-1) final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) del final_result[key] del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] + if "output.weight" not in final_result: + # lm_head.weight may not be explicitly stored in the HF checkpoint if input and output embeddings are shared + final_result["output.weight"] = final_result[ + "tok_embeddings.weight" + ].clone() else: final_result = merged_result - print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") - torch.save(final_result, checkpoint_dir / "model.pth") if is_llama3: original_dir = checkpoint_dir / "original" tokenizer_model = original_dir / "tokenizer.model" tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model" print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}") shutil.copy(tokenizer_model, tokenizer_model_tiktoken) + elif "Llama-3" in model_name: # Can be one of the finetunes of Llama-3 + path = hf_hub_download( + repo_id = "meta-llama/Meta-Llama-3-8B", + filename="tokenizer.model", + subfolder="original", + ) + shutil.copy(path, checkpoint_dir / "tokenizer.model") + + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") + torch.save(final_result, out_model_path) + -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') - parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) - parser.add_argument('--model_name', type=str, default=None) + + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") + parser.add_argument( + "--checkpoint_dir", + type=Path, + default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), + ) + parser.add_argument("--model_name", type=str, default=None) args = parser.parse_args() convert_hf_checkpoint( checkpoint_dir=args.checkpoint_dir, model_name=args.model_name, - ) + ) \ No newline at end of file diff --git a/scripts/download.py b/scripts/download.py index a968cf3..3d7a158 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -11,20 +11,45 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) + + # if directory is not empty, don't download + if os.listdir(f"checkpoints/{repo_id}"): + print( + f'Directory checkpoints/{repo_id} is not empty, skipping download. First, "rm -rf checkpoints/{repo_id}" if you want to re-download.' + ) + return + try: - snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token) + snapshot_download( + repo_id, + local_dir=f"checkpoints/{repo_id}", + local_dir_use_symlinks=False, + token=hf_token, + ) except HTTPError as e: if e.response.status_code == 401: - print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) else: raise e -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') - parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') - parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.') + + parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") + parser.add_argument( + "--repo_id", + type=str, + default="karpathy/tinyllamas", + help="Repository ID to download from.", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token." + ) args = parser.parse_args() hf_download(args.repo_id, args.hf_token) diff --git a/scripts/prepare.sh b/scripts/prepare.sh index 43a0baa..600ccac 100755 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -1 +1 @@ -python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 && python quantize.py --checkpoint_path checkpoints/$1/model.pth --mode int8 +python scripts/download.py --repo_id $1 && python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 diff --git a/scripts/prepare_llama3.sh b/scripts/prepare_llama3.sh new file mode 100644 index 0000000..446ad08 --- /dev/null +++ b/scripts/prepare_llama3.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -e + +# Use env vars if they exist, otherwise set defaults +: "${HF:=meta-llama/Meta-Llama-3-8B-Instruct}" +: "${TRUNC_LAYERS:=4}" + +# Export the variables +export HF +export TRUNC_LAYERS + + +python scripts/download.py --repo_id $HF +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$HF +python scripts/truncate.py --checkpoint_dir checkpoints/$HF --trunc_layers $TRUNC_LAYERS diff --git a/scripts/truncate.py b/scripts/truncate.py new file mode 100644 index 0000000..01b66d9 --- /dev/null +++ b/scripts/truncate.py @@ -0,0 +1,65 @@ +import os +from pathlib import Path +import re +import shutil + +import argparse +import torch + + +def keep_weight(key, trunc_layers): + if "layers" in key: + layer_num = int(re.search(r"layers\.(\d+)", key).group(1)) + return layer_num < trunc_layers + else: + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Script to remove all but the first K layers of model. For Debugging." + ) + parser.add_argument( + "--checkpoint_dir", + type=Path, + default=Path("checkpoints/meta-llama/Meta-Llama-3-8B-Instruct"), + ) + parser.add_argument("--trunc_layers", type=int, default=4) + + # Add more arguments if needed + args = parser.parse_args() + + trunc_dir = args.checkpoint_dir.with_name( + args.checkpoint_dir.name + f"-{args.trunc_layers}-Layers" + ) + os.makedirs(trunc_dir, exist_ok=True) + + # If trunc_dir has tokenizer.model file, exit without error + if (trunc_dir / "tokenizer.model").exists(): + print(f"Truncated model already exists at {trunc_dir}. Exiting without error.") + exit(0) + + # Copy tokenizer.model file to trunc_dir + shutil.copy(args.checkpoint_dir / "tokenizer.model", trunc_dir / "tokenizer.model") + + weights = torch.load(args.checkpoint_dir / "model.pth", map_location="cpu") + new_weights = dict( + { + key: value + for key, value in weights.items() + if keep_weight(key, args.trunc_layers) + } + ) + torch.save(new_weights, trunc_dir / "model.pth") + + orig_size = sum([value.numel() for value in weights.values()]) + + new_size = sum([value.numel() for value in new_weights.values()]) + + print( + f"Reduced number of parameters from {orig_size} to {new_size} by truncating to first {args.trunc_layers} layers." + ) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index a4ea6ea..6c879ef 100644 --- a/setup.py +++ b/setup.py @@ -6,14 +6,14 @@ from setuptools import setup, find_packages setup( - name='gpt-fast', - version='0.1', + name="gpt-fast", + version="0.1", packages=find_packages(), install_requires=[ - 'torch', + "torch", ], - description='A simple, fast, pure PyTorch Llama inference engine', - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - url='https://github.com/pytorch-labs/gpt-fast', + description="A simple, fast, pure PyTorch Llama inference engine", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/pytorch-labs/gpt-fast", ) diff --git a/summarize.py b/summarize.py new file mode 100644 index 0000000..5d4d1a2 --- /dev/null +++ b/summarize.py @@ -0,0 +1,258 @@ +import argparse +import os + +from claudette import models as anthropic_models, Chat + +from datasets import Dataset +import pandas as pd +import torch +from data_utils import BENCHMARKS +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + +assert ( + "ANTHROPIC_API_KEY" in os.environ +), "Please set the ANTHROPIC_API_KEY environment variable." + + +PROMPT_TEMPLATES = { + "triviaqa": "Compress the information in the retrieved documents into a 1-3 sentences " + "such that it includes only information relevant to answering the question: %s\n\nRetrieved Documents:\n%s", + "dolomites": "Compress the information in the instructions into a 1-3 sentences " + "such that it includes only information relevant to completing the task: %s\n\nInstructions:\n%s", +} + +SUMMARY_PREFILL = { + "triviaqa": "Compressed Documents: ", + "dolomites": "Compressed Instructions: ", +} + +SCORE_PREFILL = { + "triviaqa": "The answer is", + "dolomites": "The completed task is", +} + +LONGCONTEXT_DATASETS = ["dolomites"] # no RAG will be the same as original + +# We will evaluate each summary based on if it improves downstream performance: +# p(answer|question, summarized context) minus either p(answer|original context) or p(answer|question) +SCORER_MODEL = ( + "microsoft/Phi-3-mini-128k-instruct" # "meta-llama/Meta-Llama-3-8B-Instruct" +) + + +def compute_likelihoods( + args, + ctx, + q, + answers, + instruction, + tokenizer, + scorer, + max_ctx_len=None, + batch_size=8, +): + # We need enough tokens for the instruction, question, and answer (subtract max context by 1024 to be safe) + max_ctx_len = max_ctx_len or tokenizer.model_max_length - 1024 + ctx_tokens = len(tokenizer.encode(ctx)) + if ctx_tokens > max_ctx_len: + ctx_words = ctx.split(" ") + keep_num_words = round(len(ctx_words) * max_ctx_len / ctx_tokens) + ctx = " ".join(ctx_words[:keep_num_words]) + + chat = [ + {"role": "user", "content": f"{ctx}\n\n{instruction}\n\n{q}"}, + ] + + prompt = ( + tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + + SCORE_PREFILL[args.dataset] + ) + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + + n = len(prompt_ids) + + ans_ids = [ + tokenizer.encode(" " + ans.strip(), add_special_tokens=False) for ans in answers + ] + input_ids = [prompt_ids + ans_id for ans_id in ans_ids] + max_len = max(map(len, input_ids)) + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + # Pad the input_ids up to max_len + assert tokenizer.padding_side in {"right", "left"} + input_ids = torch.LongTensor( + [ + ids + [tokenizer.pad_token_id] * (max_len - len(ids)) + if tokenizer.padding_side == "right" + else [tokenizer.pad_token_id] * (max_len - len(ids)) + ids + for ids in input_ids + ] + ) + input_ids = input_ids.to(scorer.device) + + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") + + labels = input_ids.clone() + # No Loss on the prompt + labels[:, :n] = -100 + is_pad = input_ids == tokenizer.pad_token_id + labels[is_pad] = -100 + shift_labels = labels[..., 1:].contiguous() + + j = 0 + scores = [] + for batch_idx in range(0, len(input_ids), batch_size): + with torch.no_grad(): + logits = scorer(input_ids[batch_idx : batch_idx + batch_size]).logits + shift_logits = logits[..., :-1, :].contiguous() + for i in range(len(shift_logits)): + _shift_logits = shift_logits[i].view(-1, scorer.config.vocab_size) + _shift_labels = shift_labels[j].view(-1) + ll = (-loss_fct(_shift_logits, _shift_labels)).item() + scores.append(ll) + j += 1 + + return {"max": max(scores), "mean": sum(scores) / len(scores)} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--name", type=str, default="haiku", choices=["haiku", "sonnet", "opus"] + ) + parser.add_argument("--dataset", type=str, default="triviaqa") + parser.add_argument( + "--min_toks", + type=int, + default=50, + help="If context has < min_toks, there is no need to summarize it. The summary is itself.", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Max batch size for the scorer." + ) + parser.add_argument( + "--max_ctx_len", + type=int, + default=None, + help="An optional override to max context window of model.", + ) + parser.add_argument( + "--out_dir", + type=str, + default=".", + help="The directory to save the summarized dataset.", + ) + + args = parser.parse_args() + + model = [m for m in anthropic_models if args.name in str(m)][0] + + dataset = BENCHMARKS[args.dataset]() + + train = dataset.get_train() + + scorer = ( + AutoModelForCausalLM.from_pretrained( + SCORER_MODEL, + torch_dtype="auto", + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + .eval() + .to("cuda") + ) + + tokenizer = AutoTokenizer.from_pretrained(SCORER_MODEL) + + stats = [] + out_data = [] + + for ex in tqdm(train): + # Claudette is stateful so you need to re-instantiate each time + chat = Chat(model, sp="""You are a helpful and concise assistant.""") + + q, ctx = dataset.question(ex), dataset.context(ex) + ctx_toks = len(ctx.split(" ")) + + if ctx_toks < args.min_toks and args.dataset not in LONGCONTEXT_DATASETS: + summary = ctx + else: + prompt = PROMPT_TEMPLATES[args.dataset] % (q, ctx) + + summary = ( + chat(prompt, prefill=SUMMARY_PREFILL[args.dataset]).content[0].text + ) + + assert summary.startswith(SUMMARY_PREFILL[args.dataset]) + summary = summary[len(SUMMARY_PREFILL[args.dataset]) :].strip() + + answers = dataset.answer(ex) + + if type(answers) == str: + answers = [answers] + + original_scores = compute_likelihoods( + args, + ctx, + q, + answers, + dataset.instruction(), + tokenizer, + scorer, + max_ctx_len=args.max_ctx_len, + batch_size=args.batch_size, + ) + compressed_scores = compute_likelihoods( + args, + summary, + q, + answers, + dataset.instruction(), + tokenizer, + scorer, + max_ctx_len=args.max_ctx_len, + batch_size=args.batch_size, + ) + + if args.dataset not in LONGCONTEXT_DATASETS: + no_rag_scores = compute_likelihoods( + args, + "", + q, + answers, + dataset.instruction(), + tokenizer, + scorer, + max_ctx_len=args.max_ctx_len, + batch_size=args.batch_size, + ) + + else: + no_rag_scores = original_scores + + stats.append( + { + "original_mean": original_scores["mean"], + "original_max": original_scores["max"], + "compressed_mean": compressed_scores["mean"], + "compressed_max": compressed_scores["max"], + "no_rag_mean": no_rag_scores["mean"], + "no_rag_max": no_rag_scores["max"], + } + ) + + print("Running statistics...") + print(pd.DataFrame(stats).mean()) + + out_row = ex.copy() + out_row.update({"summary_prompt": prompt, "summary": summary}) + out_data.append(out_row) + + dataset = Dataset.from_list(out_data) + out_hf_path = os.path.join(args.out_dir, f"{args.dataset}_summarized") + print(f"Saving data with downstream scores to {out_hf_path}") + dataset.save_to_disk(out_hf_path) diff --git a/task.py b/task.py new file mode 100644 index 0000000..26dc0c9 --- /dev/null +++ b/task.py @@ -0,0 +1,819 @@ +import random +from abc import ABC, abstractmethod +from string import ascii_uppercase +from pathlib import Path + +import numpy as np +import pandas as pd +from datasets import load_dataset + +from metric import AutoMetric +from tokenizer import get_tokenizer + + +class EvaluationTask(ABC): + train_split: str = "train" + validation_split: str = "validation" + test_split: str = "test" + mandatory_cols = ["context", "question", "prompt", "labels"] + requires_logits = False + + def __init__( + self, + prompt_template, + max_tokens, + model_max_length, + tokenizer, + hf_args=None, + **kwargs, + ): + self.prompt_template = prompt_template + self.max_tokens = max_tokens + self.model_max_length = model_max_length + self.tokenizer = tokenizer + self.hf_args = hf_args + self.num_samples = kwargs.pop("num_samples", None) + + # Download the dataset + self._download() + + # Lazy process each split as needed + self.is_ready = { + self.train_split: False, + self.validation_split: False, + self.test_split: False, + } + + def _download(self): + # Can over-write if not using HF + self.dataset = load_dataset(*self.hf_args) + + def get_split(self, split): + remove_cols = [ + col + for col in self.dataset[split].column_names + if col not in self.mandatory_cols + ] + if not self.is_ready[split]: + split_data = self.dataset[split] + split_data = split_data.map( + self.prepare_batch, batched=True, remove_columns=remove_cols + ) + + # Filter out examples that could be too long for the model + filtered_data = split_data.filter( + lambda x: len(self.tokenizer(x["prompt"])) + self.max_tokens + <= self.model_max_length + ) + print( + f"Filtered {len(split_data) - len(filtered_data)} examples from split {split}" + ) + + if self.num_samples is not None and len(filtered_data) > self.num_samples: + n = min(self.num_samples, len(filtered_data)) + print(f"Randomly sample {n} examples") + # Use a fixed seed for reproducibility + inds = random.Random(n).sample(range(len(filtered_data)), n) + filtered_data = filtered_data.select(inds) + + self.dataset[split] = filtered_data + self.is_ready[split] = True + + return self.dataset[split] + + def get_train(self): + return self.get_split(self.train_split) + + def get_validation(self): + return self.get_split(self.validation_split) + + def get_test(self): + return self.get_split(self.test_split) + + def compute_metrics(self, predictions, split, dataset): + assert self.is_ready[split], f"Split {split} has not been processed yet." + assert ( + len(dataset) == len(predictions) + ), f"Number of predictions and labels must match ({len(predictions)} != {len(dataset)})." + return self._compute_metrics(dataset["prompt"], predictions, dataset["labels"]) + + def _compute_metrics( + self, prompts: list, predictions: list, labels: list[str | list[str]] + ): + return { + metric_name: metric.compute(prompts, predictions, labels) + for metric_name, metric in self.metrics.items() + } + + def train_metrics(self, predictions): + return self.compute_metrics(predictions, self.train_split, self.get_train()) + + def validation_metrics(self, predictions): + return self.compute_metrics( + predictions, self.validation_split, self.get_validation() + ) + + def test_metrics(self, predictions): + return self.compute_metrics(predictions, self.test_split, self.get_test()) + + def prepare_batch(self, batch): + keys = list(batch.keys()) + n = len(batch[keys[0]]) + processed = {k: [] for k in self.mandatory_cols} + for i in range(n): + row = {k: v[i] for k, v in batch.items()} + out = {k: None for k in self.mandatory_cols} + out = self.prepare_row(row) + # Most tasks will return a single dictionary example from a single row + if type(out) != list: + out = [out] + for x in out: + for k in self.mandatory_cols: + processed[k].append(x.get(k, None)) + return processed + + @abstractmethod + def prepare_row(self, row) -> dict | list[dict]: + """Process a single row from the dataset.""" + pass + + +class LogitEvaluationTask(EvaluationTask): + def __init__(self, prompt_template, max_tokens, hf_args=None, **kwargs): + super().__init__(prompt_template, max_tokens, hf_args=hf_args, **kwargs) + self.requires_logits = True + + @abstractmethod + def _process_logits(self, logits, split): + """Process logits and return predictions.""" + pass + + def compute_metrics(self, predictions, split, dataset): + # LogitEvaluationTask will get logits instead of token predictions, so we need to process them first + predictions = self._process_logits(predictions, split) + return super().compute_metrics(predictions, split, dataset) + + +class Squality(EvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You are given a story and a question. Answer the question in a single paragraph. + +====STORY==== +{story} + +====QUESTION==== +{question}""" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1024, **kwargs + ): + super().__init__( + prompt_template, max_tokens, hf_args=["pszemraj/SQuALITY-v1.3"], **kwargs + ) + + self.metrics = { + "BertScore": AutoMetric.from_name("bertscore"), + "Rouge": AutoMetric.from_name("rouge"), + "LLM-Rouge": AutoMetric.from_name("llm-rouge"), + "LLM-Judge": AutoMetric.from_name("llm-as-a-judge"), + } + + def prepare_row(self, row: dict): + story = row["document"].strip() + questions = row["questions"] + out = [] + for question in questions: + question_text = question["question_text"].strip() + prompt = self.prompt_template.format( + story=story, question=question["question_text"] + ) + labels = [resp["response_text"].strip() for resp in question["responses"]] + out_row = { + "prompt": prompt, + "context": story, + "question": question_text, + "labels": labels, + } + out.append(out_row) + return out + + +class TriviaQA(EvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You are given a question and potentially relevant context from Wikipedia. Answer the question without any explanation. + +====CONTEXT==== +{context} + +====QUESTION==== +{question}""" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1024, **kwargs + ): + self.use_web = kwargs.pop("use_web", False) + + super().__init__( + prompt_template, max_tokens, hf_args=["trivia_qa", "rc"], **kwargs + ) + + self.metrics = { + "BertScore": AutoMetric.from_name("bertscore"), + "Rouge": AutoMetric.from_name("rouge"), + "LLM-Rouge": AutoMetric.from_name("llm-rouge"), + } + + def prepare_row(self, row: dict): + wikis = row["entity_pages"] + webs = row["search_results"] + + wiki_n = len(wikis["title"]) + web_n = len(webs["title"]) + + contexts = [] + + for i in range(wiki_n): + contexts.append("# " + wikis["title"][i] + "\n" + wikis["wiki_context"][i]) + + if self.use_web: + for j in range(web_n): + contexts.append( + "# " + + webs["title"][j] + + "\n" + + webs["description"][j] + + "\n" + + webs["search_context"][j] + ) + + context_str = "\n\n".join(contexts) + question = row["question"] + labels = row["answer"]["aliases"] + if row["answer"]["value"] not in labels: + labels.append(row["answer"]["value"]) + assert len(labels) > 0 + return { + "context": context_str, + "question": question, + "prompt": self.prompt_template.format( + context=context_str, question=question + ), + "labels": labels, + } + + +class Dolomites(EvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You need to perform a writing task from the field of {field}. +You are given (1) a task description which contains input and output sections, and (2) an example input for this task, which is a sample of the input sections of the task with concrete details. +You need to generate the output sections for the given example input. + +IMPORTANT: +- Make sure the length of each output section matches the required length and the section headers are exactly the same. +- Make sure the output follows the structure of the output sections in the task description, is factually accurate and detailed. + +====TASK DESCRIPTION==== +{task_description} + +====EXAMPLE INPUT==== +{example_input}""" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1024, **kwargs + ): + super().__init__( + prompt_template, max_tokens, hf_args=["fladhak/dolomites"], **kwargs + ) + + # Dolomites test split does not have references, so we will use validation split for testing + self.test_split = "validation" + + self.metrics = { + "BertScore": AutoMetric.from_name("bertscore"), + "Rouge": AutoMetric.from_name("rouge"), + "LLM-Rouge": AutoMetric.from_name("llm-rouge"), + "LLM-Judge": AutoMetric.from_name("llm-as-a-judge"), + } + + def prepare_row(self, row: dict): + field = row["field"] + task_objective = row["task_objective"] + task_procedure = row["task_procedure"] + task_input = row["task_input"] + task_output = row["task_output"] + task_notes = row["task_notes"] + example_input = row["example_input"] + ref = row["example_output"] + + task_description = f"Task objective: {task_objective}\nTask prodecedure: {task_procedure}\nTask input: {task_input}\nTask output: {task_output}" + if task_notes is not None: + task_description += f"\nAdditional notes: {task_notes}" + + prompt = self.prompt_template.format( + field=field, task_description=task_description, example_input=example_input + ) + + return { + "prompt": prompt, + "field": field, + "context": task_description, + "question": example_input, + "labels": ref, + } + + +class QMSum(EvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You will be shown a meeting transcipt along with a query. Your task is to carefully read the transcript and provide a concise answer to the query. + +====MEETING TRANSCRIPT==== +{transcript} + +====QUERY==== +{query}""" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1024, **kwargs + ): + super().__init__( + prompt_template, max_tokens, hf_args=["fladhak/qmsum"], **kwargs + ) + + self.metrics = { + "BertScore": AutoMetric.from_name("bertscore"), + "Rouge": AutoMetric.from_name("rouge"), + "LLM-Rouge": AutoMetric.from_name("llm-rouge"), + "LLM-Judge": AutoMetric.from_name("llm-as-a-judge"), + } + + def prepare_row(self, row: dict): + transcript = "\n\n".join( + [f"{x['speaker']}: {x['content']}" for x in row["transcript"]] + ) + query = row["query"] + answer = row["answer"] + + prompt = self.prompt_template.format(transcript=transcript, query=query) + + return { + "prompt": prompt, + "context": transcript, + "labels": answer, + } + + +class Musique(EvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You will be shown several paragraphs from Wikipedia along with a question. Your task is to carefully read the paragraphs and provide a concise answer to the question. +IMPORTANT: You should only use the infomation provided in the paragraphs to answer the question. + +====PARAGRAPHS==== +{paragraphs} + +====QUESTION==== +{question}""" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=128, **kwargs + ): + super().__init__( + prompt_template, max_tokens, hf_args=["fladhak/musique"], **kwargs + ) + + # Musique test split does not have references, so we will use validation split for testing + self.test_split = "validation" + + self.metrics = { + "BertScore": AutoMetric.from_name("bertscore"), + "Rouge": AutoMetric.from_name("rouge"), + "LLM-Rouge": AutoMetric.from_name("llm-rouge"), + "LLM-Judge": AutoMetric.from_name("llm-as-a-judge"), + } + + def prepare_row(self, row: dict): + paragraphs = "\n\n".join( + [f"{x['title']}:\n{x['paragraph_text']}" for x in row["paragraphs"]] + ) + question = row["question"] + answers = [row["answer"]] + row["answer_aliases"] + + prompt = self.prompt_template.format(paragraphs=paragraphs, question=question) + + return { + "prompt": prompt, + "context": paragraphs, + "question": question, + "labels": answers, + } + + +class TruthfulQA(LogitEvaluationTask): + DEFAULT_PROMPT_TEMPLATE = """You will be shown a question along with several possible answers. Please carefully read the question and the answer choices and pick the best answer. +IMPORTANT: You should simply provide the letter corresponding to the answer choice that you picked. You do not need to write out the entire answer or provide any explanation. + +====QUESTION==== +{question} + +====ANSWER CHOICES==== +{choices}""" + + def __init__(self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1, **kwargs): + super().__init__( + prompt_template, + max_tokens, + hf_args=["truthfulqa/truthful_qa", "multiple_choice"], + **kwargs, + ) + + # Musique test split does not have references, so we will use validation split for testing + self.test_split = "validation" + + self.metrics = { + "Accuracy": AutoMetric.from_name("accuracy"), + } + self.mandatory_cols = self.mandatory_cols.copy() + ["num_choices"] + + def prepare_row(self, row: dict): + question = row["question"] + choices = "\n".join( + [ + f"{char}. {opt}" + for char, opt in zip(ascii_uppercase, row["mc1_targets"]["choices"]) + ] + ) + answer = ascii_uppercase[row["mc1_targets"]["labels"].index(1)] + + prompt = self.prompt_template.format(question=question, choices=choices) + + return { + "prompt": prompt, + "question": question, + "context": choices, + "labels": answer, + "num_choices": len(row["mc1_targets"]["choices"]), + } + + def _process_logits(self, logits, split): + preds = [] + for l, nc in zip(logits, self.get_split(split)["num_choices"]): + pred = [l[ascii_uppercase[i]] for i in range(nc)] + preds.append(ascii_uppercase[np.argmax(pred)]) + + return preds + + +class ScrollsQuality(LogitEvaluationTask): + """ + Evaluation dataset derived from `tau/scrolls`. + It is processed into a suitable format here: https://huggingface.co/datasets/rbiswasfc/quality. + Test split doesn't have ground truths, hence it will use validation split as an alternative. + """ + + DEFAULT_PROMPT_TEMPLATE = """You will be given a context, a question related to that context, and four possible answer choices. Carefully read the context, question, and answer choices, then select the best answer. +IMPORTANT: Provide only the letter corresponding to your chosen answer. Do not write out the full answer or give any explanation. + +====CONTEXT==== +{context} + +====QUESTION==== +{question} + +====ANSWER CHOICES==== +{choices}""" + + def __init__(self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=1, **kwargs): + super().__init__( + prompt_template, max_tokens, hf_args=["rbiswasfc/quality"], **kwargs + ) + + self.metrics = { + "Accuracy": AutoMetric.from_name("accuracy"), + } + self.test_split = "validation" # Test split doesn't have ground truths - use validation split + + self.mandatory_cols = self.mandatory_cols.copy() + ["num_choices"] + + def prepare_row(self, row: dict): + context = row["context"] + question = row["question"] + choices = row["choices"] + num_choices = len(choices) + answer = ascii_uppercase[row["label"]] + + choices = "\n".join( + [f"{char}. {opt}" for char, opt in zip(ascii_uppercase, choices)] + ) + + return { + "context": context, + "question": question, + "prompt": self.prompt_template.format( + context=context, question=question, choices=choices + ), + "labels": answer, + "num_choices": num_choices, + } + + def _process_logits(self, logits, split): + preds = [] + for l, nc in zip(logits, self.get_split(split)["num_choices"]): + pred = [l[ascii_uppercase[i]] for i in range(nc)] + preds.append(ascii_uppercase[np.argmax(pred)]) + + return preds + + +class RulerQA(EvaluationTask): + """ + RULER hotpotqa task with 8k context length. (context length can be adjusted as needed) + """ + + DEFAULT_PROMPT_TEMPLATE = "{task_input}" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=32, **kwargs + ): + super().__init__( + prompt_template, + max_tokens, + hf_args=["rbiswasfc/ruler", "qa_2_8k"], + **kwargs, + ) + + self.metrics = { + "StringMatch": AutoMetric.from_name("ruler-string-match", match_part=True), + } + self.test_split = "validation" + + def prepare_row(self, row: dict): + task_input = row["input"] + + question = task_input.split("Question:")[-1].split("Answer:")[0].strip() + context = task_input.split("Question:")[0].strip() + + prompt = self.prompt_template.format(task_input=task_input) + answer = row["outputs"] # List[str] + + return { + "context": context, + "question": question, + "prompt": prompt, + "labels": answer, + } + + +class RulerNIAH(EvaluationTask): + """ + RULER Multi-keys Needle-in-a-haystack (NIAH) task with 8k context length. (context length can be adjusted as needed) + """ + + DEFAULT_PROMPT_TEMPLATE = "{task_input}" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=128, **kwargs + ): + super().__init__( + prompt_template, + max_tokens, + hf_args=["rbiswasfc/ruler", "niah_multikey_1_8k"], + **kwargs, + ) + + self.metrics = { + "StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False), + } + self.test_split = "validation" + + def prepare_row(self, row: dict): + task_input = row["input"] + + question = ( + "The special magic number for fair-sprout mentioned in the provided text is" + ) + context = task_input + + prompt = self.prompt_template.format(task_input=task_input) + answer = row["outputs"] # List[str] + + return { + "context": context, + "question": question, + "prompt": prompt, + "labels": answer, + } + + +class RulerVT(EvaluationTask): + """ + RULER Multi-hop Tracing: Variable Tracking (VT) task with 8k context length. (context length can be adjusted as needed) + """ + + DEFAULT_PROMPT_TEMPLATE = "{task_input}" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=30, **kwargs + ): + super().__init__( + prompt_template, + max_tokens, + hf_args=["rbiswasfc/ruler", "vt_8k"], + **kwargs, + ) + + self.metrics = { + "StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False), + } + self.test_split = "validation" + + def prepare_row(self, row: dict): + task_input = row["input"] + + question = task_input.split("Question:")[-1].split("Answer:")[0].strip() + context = task_input.split("Question:")[0].strip() + + prompt = self.prompt_template.format(task_input=task_input) + answer = row["outputs"] # List[str] + + return { + "context": context, + "question": question, + "prompt": prompt, + "labels": answer, + } + + +class RulerCWE(EvaluationTask): + """ + RULER Aggregation: Common Words (CWE) task with 8k context length. (context length can be adjusted as needed) + """ + + DEFAULT_PROMPT_TEMPLATE = "{task_input}" + + def __init__( + self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=120, **kwargs + ): + super().__init__( + prompt_template, + max_tokens, + hf_args=["rbiswasfc/ruler", "cwe_8k"], + **kwargs, + ) + + self.metrics = { + "StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False), + } + self.test_split = "validation" + + def prepare_row(self, row: dict): + task_input = row["input"] + + question = task_input.split("Question:")[-1].split("Answer:")[0].strip() + context = task_input.split("Question:")[0].strip() + + prompt = self.prompt_template.format(task_input=task_input) + answer = row["outputs"] # List[str] + + return { + "context": context, + "question": question, + "prompt": prompt, + "labels": answer, + } + + +TASK_MAPPING = { + "squality": Squality, + "triviaqa": TriviaQA, + "dolomites": Dolomites, + "qmsum": QMSum, + "musique": Musique, + "truthfulqa": TruthfulQA, + "scrollsquality": ScrollsQuality, + "rulerqa": RulerQA, + "rulerniah": RulerNIAH, + "rulervt": RulerVT, + "rulercwe": RulerCWE, +} + + +class AutoTask: + def __init__(self): + raise EnvironmentError( + "This class is designed to be instantiated only through the from_name method" + ) + + def from_name(task_name, **kwargs): + if task_name not in TASK_MAPPING: + raise ValueError( + f"Task {task_name} not found. Available tasks: {TASK_MAPPING.keys()}" + ) + return TASK_MAPPING[task_name](**kwargs) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser("Test out implementation of EvaluationTask") + + parser.add_argument( + "--task", type=str, default="triviaqa", choices=TASK_MAPPING.keys() + ) + parser.add_argument("--compute_stats", action="store_true", default=False) + parser.add_argument("--num_samples", default=int(1e10), type=int) + + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path(__file__).resolve().parent + / "checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth", + help="Model checkpoint path.", + ) + + args = parser.parse_args() + + is_chat = ( + "chat" in str(args.checkpoint_path).lower() + or "instruct" in str(args.checkpoint_path).lower() + ) + + tokenizer_path = args.checkpoint_path.parent / "tokenizer.model" + if not tokenizer_path.is_file(): + # If there's no tokenizer.model, try to load the tokenizer from the parent directory + # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers + tokenizer_path = args.checkpoint_path.parent + + tokenizer = get_tokenizer(tokenizer_path, args.checkpoint_path, is_chat=is_chat) + + # Dummy values + task_kwargs = { + "model_max_length": int(1e10), + "num_samples": args.num_samples, + "tokenizer": tokenizer.encode_prompt if is_chat else tokenizer.encode, + } + + def num_toks(x): + return len(task_kwargs["tokenizer"](x)) + + if args.compute_stats: + stats = [] + for task_name in TASK_MAPPING.keys(): + print(f"Computing stats for {task_name}") + task = AutoTask.from_name(task_name, **task_kwargs) + test = task.get_test() + + prompts = test["prompt"] + labels = test["labels"] + + prompt_tokens = sum([num_toks(p) for p in test["prompt"]]) / len(test) + num_references = sum( + [1 if type(l) != list else len(l) for l in labels] + ) / len(test) + + avg_reference_len = [] + for l in labels: + if type(l) != list: + l = [l] + avg_reference_len.append(sum([num_toks(x) for x in l]) / len(l)) + avg_reference_len = sum(avg_reference_len) / len(avg_reference_len) + + avg_n_choices = ( + None + if "num_choices" not in test + else sum(test["num_choices"]) / len(test) + ) + + stats.append( + { + "task": task_name, + "n": len(test), + "is_mcqa": task.requires_logits, + "prompt_tokens": prompt_tokens, + "label_tokens": avg_reference_len, + "n_choices": avg_n_choices, + } + ) + + stats = pd.DataFrame(stats) + stats_fn = Path(__file__).parent / "cache_configs" / "task_stats.csv" + stats = stats.sort_values("task").reset_index(drop=True) + stats.to_csv(stats_fn, index=False) + else: + task = AutoTask.from_name(args.task, **task_kwargs) + test = task.get_test() + print("Example test datapoint:\n\n") + ex = test[0] + for k, v in ex.items(): + print(f"{k}:\n{v}\n\n") + + train_predictions = ["This is a train prediction"] * len(task.dataset["train"]) + test_predictions = ["This is a test prediction"] * len(test) + + print("A 'not ready' error should be displayed below:\n\n") + try: + task.train_metrics(predictions=train_predictions) + except Exception as e: + print(e) + + print("A 'length mismatch' error should be displayed below:\n\n") + try: + task.test_metrics(predictions=test_predictions[:-1]) + except Exception as e: + print(e) + + print("Dummy metrics for test split:\n\n") + print(task.test_metrics(predictions=test_predictions)) diff --git a/tokenizer.py b/tokenizer.py index c62a0c5..4f58bc7 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -1,32 +1,115 @@ +from abc import ABC, abstractmethod +import itertools import os +import regex as re +import string import sentencepiece as spm import tiktoken +import torch from tiktoken.load import load_tiktoken_bpe +from transformers import AutoTokenizer from pathlib import Path -from typing import Dict +from typing import ( + Dict, + List, + Literal, + TypedDict, +) -class TokenizerInterface: + +default_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def is_punc_id(text): + # Define a regex pattern that matches any character that is not whitespace or punctuation + pattern = rf"^[\s{re.escape(string.punctuation)}]*$" + return bool(re.match(pattern, text)) + + +class TokenizerInterface(ABC): def __init__(self, model_path): self.model_path = model_path + self.vocab = None + @abstractmethod def encode(self, text): - raise NotImplementedError("This method should be overridden by subclasses.") + pass + @abstractmethod def decode(self, tokens): - raise NotImplementedError("This method should be overridden by subclasses.") + pass + @abstractmethod def bos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") + pass + @abstractmethod def eos_id(self): - raise NotImplementedError("This method should be overridden by subclasses.") + pass + + @abstractmethod + def get_terminator_ids(self): + pass + + @abstractmethod + def special_ids(self) -> List[List[int]]: + pass + + @abstractmethod + def __len__(self): + pass + + def punctuation_ids(self): + return [i for i, wp in enumerate(self.vocab) if is_punc_id(wp)] + + def get_vocab(self): + assert ( + self.vocab is not None + ), "Subclasses should set the vocab attribute during initialization." + return self.vocab + class SentencePieceWrapper(TokenizerInterface): def __init__(self, model_path): super().__init__(model_path) + self.model_path = model_path self.processor = spm.SentencePieceProcessor(str(model_path)) + self.terminator_ids = [self.processor.eos_id()] + self.vocab = [ + self.processor.id_to_piece(id) + for id in range(self.processor.get_piece_size()) + ] - def encode(self, text): + def addl_special_ids(self): + # If llama-2 in model path, return special tokens for llama-2 + if "llama-2" in str(self.model_path).lower(): + special_tokens = ["[INST]", "[/INST]"] + else: + raise ValueError(f"Unknown model path: {self.model_path}") + + def _encode_special(token): + ids = self.processor.EncodeAsIds(token) + if len(ids) > 1: + print(f"Special token {token} was tokenized into {len(ids)} tokens") + return ids + + return list(map(_encode_special, special_tokens)) + + def special_ids(self) -> List[List[int]]: + # Some of the chat templates aren't given a singular special token so we return a list of lists + return [ + [self.processor.bos_id()], + [self.processor.eos_id()], + *self.addl_special_ids(), + ] + + def encode(self, prompt): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += "\n" + prompt['input'] + else: + text = prompt return self.processor.EncodeAsIds(text) def decode(self, tokens): @@ -38,7 +121,13 @@ def bos_id(self): def eos_id(self): return self.processor.eos_id() -class TiktokenWrapper(TokenizerInterface): + def get_terminator_ids(self): + return self.terminator_ids + + def __len__(self): + return self.processor.get_piece_size() + +class Llama3Wrapper(TokenizerInterface): """ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. """ @@ -68,7 +157,7 @@ def __init__(self, model_path): ] + [ f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5) - ] + ] self.special_tokens = { token: num_base_tokens + i for i, token in enumerate(special_tokens) } @@ -81,10 +170,22 @@ def __init__(self, model_path): # BOS / EOS token IDs self._bos_id: int = self.special_tokens["<|begin_of_text|>"] self._eos_id: int = self.special_tokens["<|end_of_text|>"] + self.terminator_ids = [self._eos_id, self.special_tokens["<|eot_id|>"]] + self.vocab = [self.model.decode([i]) for i in range(self.model.n_vocab)] - def encode(self, text): + def encode(self, prompt): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += "\n" + prompt['input'] + else: + text = prompt return self.model.encode(text) + def special_ids(self) -> List[List[int]]: + # Some of the chat templates aren't given a singular special token so we return a list of lists + return [[x] for x in list(sorted(self.special_tokens.values()))] + def decode(self, tokens): return self.model.decode(tokens) @@ -94,10 +195,163 @@ def bos_id(self): def eos_id(self): return self._eos_id -def get_tokenizer(tokenizer_model_path, model_name): + def get_terminator_ids(self): + return self.terminator_ids + + def __len__(self): + return self.model.n_vocab + +class Llama3GistWrapper(TokenizerInterface): """ - Factory function to get the appropriate tokenizer based on the model name. + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path, gist_position="instruction_end"): + super().__init__(model_path) + assert os.path.isfile(model_path), str(model_path) + mergeable_ranks = load_tiktoken_bpe(str(model_path)) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + ["<|gist|>"] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + # BOS / EOS token IDs + self._bos_id: int = self.special_tokens["<|begin_of_text|>"] + self._eos_id: int = self.special_tokens["<|end_of_text|>"] + self.terminator_ids = [self._eos_id, self.special_tokens["<|eot_id|>"]] + self.gist_id = self.special_tokens["<|gist|>"] + self.gist_position = gist_position + self.prompt_with_input = ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ) + self.prompt_without_input = ( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ) + self.vocab = [self.model.decode([i]) for i in range(self.model.n_vocab)] + + def encode_prompt(self, text): + return self.encode(text) + + def encode(self, prompt): + if isinstance(prompt, str): + prompt = {"instruction": prompt} + return self.encode(prompt) + prompt_template = self.prompt_with_input if 'input' in prompt else self.prompt_without_input + if "input" not in prompt: + prompt = prompt_template.format(instruction=prompt['instruction'] + "<|gist|>") + elif self.gist_position == "instruction": + prompt = prompt_template.format(instruction=prompt["instruction"] + "<|gist|>", input=prompt["input"]) + elif self.gist_position == "input": + prompt = prompt_template.format(instruction=prompt["instruction"], input=prompt["input"] + "<|gist|>") + elif self.gist_position == "instruction_and_input": + prompt = prompt_template.format(instruction=prompt["instruction"] + "<|gist|>", input=prompt["input"] + "<|gist|>") + return self.model.encode(prompt, allowed_special={'<|gist|>'}) + + def special_ids(self) -> List[List[int]]: + # Some of the chat templates aren't given a singular special token so we return a list of lists + return [[x] for x in list(sorted(self.special_tokens.values()))] + + def decode(self, tokens): + return self.model.decode(tokens) + + def bos_id(self): + return self._bos_id + + def eos_id(self): + return self._eos_id + + def get_terminator_ids(self): + return self.terminator_ids + + def gist_token_id(self): + return self.gist_id + def __len__(self): + return len(self.model.n_vocab) + + +class TokenizersWrapper(TokenizerInterface): + def __init__(self, model_path): + super().__init__(model_path) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.terminator_ids = [self.tokenizer.eos_token_id] + self.vocab = [ + self.tokenizer.decode(i) for i in range(self.tokenizer.vocab_size) + ] + + def special_ids(self) -> List[List[int]]: + if hasattr(self.tokenizer, "special_token_ids"): + return [[x] for x in self.tokenizer.special_token_ids] + + # Its likely a tokenizer that has a special_tokens_map attribute + special_tokens_ = list(self.tokenizer.special_tokens_map.values()) + special_tokens = [] + for t in special_tokens_: + if type(t) == list: + special_tokens.extend(t) + else: + special_tokens.append(t) + special_tokens = list(set(special_tokens)) + return [[self.tokenizer.convert_tokens_to_ids(t)] for t in special_tokens] + + def encode(self, prompt): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += "\n" + prompt['input'] + else: + text = prompt + return self.model.encode(text) + + def decode(self, tokens): + return self.tokenizer.decode(tokens) + + def bos_id(self): + return self.tokenizer.bos_token_id + + def eos_id(self): + return self.tokenizer.eos_token_id + + def get_terminator_ids(self): + return self.terminator_ids + + def __len__(self): + return len(self.tokenizer) + + +def get_tokenizer(tokenizer_model_path, model_name, is_chat=False): + """ + Factory function to get the appropriate tokenizer based on the model name. + Args: - tokenizer_model_path (str): The file path to the tokenizer model. - model_name (str): The name of the model, used to determine the tokenizer type. @@ -105,7 +359,127 @@ def get_tokenizer(tokenizer_model_path, model_name): Returns: - TokenizerInterface: An instance of a tokenizer. """ - if "Llama-3" in str(model_name): - return TiktokenWrapper(tokenizer_model_path) + model_name = str(model_name).lower() + if "gist" in model_name: + if "instruction-only" in model_name: + return Llama3GistWrapper(tokenizer_model_path, gist_position="instruction") + elif "input-only" in model_name: + return Llama3GistWrapper(tokenizer_model_path, gist_position="input") + elif "instruction-and-input" in model_name: + return Llama3GistWrapper(tokenizer_model_path, gist_position="instruction_and_input") + else: + raise ValueError(f"Invalid gist model name: {model_name}") + if "llama-3" in model_name: + return ( + Llama3ChatFormat(tokenizer_model_path) + if is_chat + else Llama3Wrapper(tokenizer_model_path) + ) + elif "llama-2" in str(model_name).lower(): + return ( + Llama2ChatFormat(tokenizer_model_path) + if is_chat + else SentencePieceWrapper(tokenizer_model_path) + ) else: - return SentencePieceWrapper(tokenizer_model_path) + return ( + TokenizersChatFormat(tokenizer_model_path) + if is_chat + else TokenizersWrapper(tokenizer_model_path) + ) + + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +class Llama3ChatFormat(Llama3Wrapper): + def __init__(self, model_path): + super().__init__(model_path) + + def encode_header(self, message: Message) -> List[int]: + return [ + self.special_tokens["<|start_header_id|>"], + *self.encode(message["role"]), + self.special_tokens["<|end_header_id|>"], + *self.encode("\n\n"), + ] + + def encode_prompt(self, prompt): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += "\n" + prompt['input'] + return self.encode_dialog_prompt([{"role": "user", "content": text}]) + + def encode_message(self, message: Message) -> List[int]: + tokens = self.encode_header(message) + tokens.extend(self.encode(message["content"].strip())) + tokens.append(self.special_tokens["<|eot_id|>"]) + return tokens + + def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]: + return [ + self.special_tokens["<|begin_of_text|>"], + *list(itertools.chain(*map(self.encode_message, dialog))), + # Add the start of an assistant message for the model to complete. + *self.encode_header({"role": "assistant", "content": ""}), + ] + + +class Llama2ChatFormat(SentencePieceWrapper): + B_INST = "[INST]" + E_INST = "[/INST]" + + def __init__(self, model_path): + super().__init__(model_path) + + def encode_prompt(self, prompt): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += prompt['input'] + ids = [self.bos_id()] + ids += self.encode(Llama2ChatFormat.B_INST + "\n\n") + ids += self.encode(prompt + " " + Llama2ChatFormat.E_INST) + return ids + + +class TokenizersChatFormat(TokenizersWrapper): + def __init__(self, model_path): + super().__init__(model_path) + + def encode_prompt(self, prompt: str): + if isinstance(prompt, dict): + text = prompt['instruction'] + if "input" in prompt: + text += "\n" + prompt['input'] + messages = [{"role": "user", "content": prompt}] + return self.encode_dialog_prompt(messages) + + def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]: + text = self.tokenizer.apply_chat_template( + dialog, tokenize=False, add_generation_prompt=True + ) + return self.encode(text) + + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def encode(tokenizer, prompt, device=default_device, is_chat=True): + if is_chat: + tokens = tokenizer.encode_prompt(prompt) + encoded = torch.tensor(tokens, dtype=torch.int, device=device) + else: + encoded = encode_tokens(tokenizer, prompt, device=device, bos=True) + + return encoded diff --git a/tp.py b/tp.py index a151a87..bf0751d 100644 --- a/tp.py +++ b/tp.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist from torch import nn + if os.uname().sysname != "Darwin": from torch.distributed import _functional_collectives as funcol else: @@ -22,17 +23,21 @@ def _get_rank() -> int: return int(os.environ.get("LOCAL_RANK", "0")) + def is_local(): return _get_rank() == 0 + def local_break(): if is_local(): breakpoint() dist.barrier() + def _get_world_size() -> int: return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) + def maybe_init_dist() -> Optional[int]: try: # provided by torchrun @@ -51,21 +56,21 @@ def maybe_init_dist() -> Optional[int]: return rank -def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None: +def _apply_tp_linear( + linear: nn.Linear, style: str, weight_splits: List[int] = [] +) -> None: rank = _get_rank() world_size = _get_world_size() # Linear's weight matrix is transposed, and is of shape # (linear.out_features, linear.in_features) - dim_lookup = { - "colwise": (0, "out_features"), - "rowwise": (1, "in_features") - } + dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")} assert style in dim_lookup shard_dim, size_attr = dim_lookup[style] # ensure we can shard evenly assert getattr(linear, size_attr) % world_size == 0 + def shard(x, dim): assert x.size(dim=dim) % world_size == 0 return torch.tensor_split(x, world_size, dim=dim)[rank] @@ -75,7 +80,7 @@ def shard_qkv(qkv, dim, weight_splits): q = shard(q, dim) k = shard(k, dim) v = shard(v, dim) - return torch.cat((q,k,v), dim=dim) + return torch.cat((q, k, v), dim=dim) # shard if weight_splits: @@ -83,8 +88,12 @@ def shard_qkv(qkv, dim, weight_splits): assert len(weight_splits) == 3 if isinstance(linear, WeightOnlyInt4Linear): - sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits]) - linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits) + sharded_weight = shard_qkv( + linear.weight, shard_dim, [i // 8 for i in weight_splits] + ) + linear.scales_and_zeros = shard_qkv( + linear.scales_and_zeros, 1 - shard_dim, weight_splits + ) else: sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) if hasattr(linear, "scales") and style == "colwise": @@ -94,7 +103,12 @@ def shard_qkv(qkv, dim, weight_splits): if isinstance(linear, WeightOnlyInt4Linear): linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim) if style == "rowwise": - assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3] + assert ( + linear.scales_and_zeros.shape[0] * 32 + == sharded_weight.shape[1] + * sharded_weight.shape[2] + * sharded_weight.shape[3] + ) assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8 if hasattr(linear, "scales") and style == "colwise": linear.scales = shard(linear.scales, 0) @@ -117,8 +131,11 @@ def _apply_tp_ffn(mlp: FeedForward) -> None: _apply_tp_linear(mlp.w2, "rowwise") world_size = _get_world_size() - mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( - output, "sum", list(range(world_size)))) + mlp.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output, "sum", list(range(world_size)) + ) + ) def _apply_tp_attn(attn: Attention) -> None: @@ -136,8 +153,11 @@ def _apply_tp_attn(attn: Attention) -> None: attn.head_dim = attn.dim // attn.n_head attn.n_local_heads = attn.n_local_heads // world_size - attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce( - output[0], "sum", list(range(world_size)))) + attn.register_forward_hook( + lambda _module, _input, output: funcol.all_reduce( + output[0], "sum", list(range(world_size)) + ) + ) def _apply_tp_Transformer(Transformer: Transformer) -> None: