Changes from all commits
89 commits
b5904ff
Prepare Llama-3-8B-Instruct.
May 20, 2024
5b2eb6a
Start prompt summarization code. Partial update since need to restart…
May 21, 2024
d3811ec
more
May 22, 2024
c2af69b
Reformatting with ruff.
May 24, 2024
06aa103
Adds fixed-window KV-cache.
May 24, 2024
a4dd428
Merge pull request #9 from AnswerDotAI/window
Jun 4, 2024
0f7192d
Update sdpa to allow for return_attn=True flag. Enables work on Heavy…
Jun 4, 2024
2e983d8
Layer specific max_cache_length. Switch to a list arg.
Jun 4, 2024
5f04c78
Moving cache prefill logic from attention module's forward() to cache…
Jun 4, 2024
5324011
Apply ruff formatting, minor cosmetic changes.
Jun 5, 2024
b933300
Add default prompts into ./prompts.
Jun 6, 2024
84dff2d
Make Llama-2 chat the default and support its chat template.
Jun 10, 2024
5bbdcf1
Implement Scissorhands paper as KVCacheScissorHands.
Jun 4, 2024
45e9854
Implements prompt compression with SnapKV and Remove Middle (keep win…
Jun 6, 2024
0669421
Merge pull request #11 from AnswerDotAI/snap
Jun 11, 2024
54c7212
Add KVCacheRandom as lowerbound baseline.
Jun 11, 2024
6f1cf31
Fix pre-existing bug in max_new_tokens. Removes T_new.
Jun 11, 2024
759a6d9
dolomites compatibility
sarahpannn Jun 12, 2024
f7e71bb
dtype auto
sarahpannn Jun 12, 2024
59f2416
Minor change to max_cache_length assertion equality statement.
Jun 12, 2024
07e4a4b
Ruff formatting and dont err out for max cache length longer than max…
Jun 13, 2024
6892ea1
Squashed commit of the following:
Jun 13, 2024
e7a2a91
Fix various bugs
VikParuchuri Jun 14, 2024
d4b4154
Merge pull request #14 from AnswerDotAI/vik/optims
VikParuchuri Jun 14, 2024
df8da8d
Fix minor Llama-3 chat template bug and apply ruff formatting.
Jun 18, 2024
0d22ad8
Comment out assertion.
Jun 18, 2024
a074f75
Adds code for evaluation, along with Squality and reference-based met…
Jun 17, 2024
1b5bda7
Merge pull request #16 from AnswerDotAI/evals
Jun 18, 2024
a4ecbba
Update Tasks to allow for train, val, test splits to be used.
Jun 18, 2024
59d133c
Merge pull request #17 from AnswerDotAI/eval2
Jun 18, 2024
3813438
Add TriviaQA.
Jun 18, 2024
292f7a9
Remove unused data_utils.py.
Jun 18, 2024
ec5b7f8
Update README.md to turn repo public.
Jun 18, 2024
7c1d928
Added Dolomites and QMSum datasets.
Jun 19, 2024
c4dc693
Do not recompute cache_kwargs 'drop_amount'. Fixes issue with torch.c…
Jun 19, 2024
56f9174
Merge pull request #18 from AnswerDotAI/evals
fladhak Jun 19, 2024
9ff5e8d
Minor change to requirements.txt
Jun 19, 2024
750e632
Added support for MuSiQue dataset, along with a small bug fix for gen…
Jun 19, 2024
8a4ae32
Merge pull request #19 from AnswerDotAI/evals
fladhak Jun 20, 2024
1b9487e
Add L2 Norm Cache and refactor prompt compression to its own file.
Jun 20, 2024
8cf2669
Merge pull request #22 from AnswerDotAI/l2
Jun 20, 2024
176d589
Minor bugfix. Remove self.updates variable since its unused.
Jun 21, 2024
e7baa14
Refactor, add L2 prompt compression, and fix important compute attent…
Jun 24, 2024
e3e1b25
Squashed commit of the following:
Jun 24, 2024
11db56e
Update generate.py to pull from generation_utils.py
Jun 25, 2024
08e374b
added quality
rbiswasfc Jun 25, 2024
3228cd2
Move cache_kwargs to cache.py.
Jun 25, 2024
037205c
Merge pull request #26 from AnswerDotAI/rb/scrolls-quality
Jun 25, 2024
2fd4341
Major refactor to simplify code and move cache_kwargs to cache.py.
Jun 25, 2024
f5cf935
Implement FastGen in a naive way with mostly sparse attention masks.
Jun 24, 2024
9180e3e
Merge pull request #25 from AnswerDotAI/profile
Jun 25, 2024
ca9d282
Ruff formatting.
Jun 25, 2024
76c2bce
Minor bugfixes.
Jun 25, 2024
5ed6e56
Add --debug mode for eval.py.
Jun 25, 2024
ff514c1
Ruff formatting.
Jun 25, 2024
eeaf233
Removed redundant method for getting wordpieces.
Jun 26, 2024
6379081
Add cache stats.
Jun 26, 2024
c198053
Minor change to accounting to count generation of terminator_id as pa…
Jun 26, 2024
cc6729f
Update stats recording.
Jun 26, 2024
0eb917e
Standardize prompt.
Jun 26, 2024
97f76a8
Filter dataset to remove examples with prompts larger than max length…
Jun 26, 2024
188183d
Save evals, greedy decoding.
Jun 26, 2024
8985880
Added logic to subsample dataset.
Jun 26, 2024
268d74e
Add LLM evals LLM-Rouge and LLM-Judge with Claudette.
Jun 26, 2024
d9cf5ab
adding ruler/qa2 4k
rbiswasfc Jun 27, 2024
2dcb22a
Fixed minor bug that was skipping eval for same task with different c…
Jun 27, 2024
7eaf591
Merge branch 'main' of https://github.com/AnswerDotAI/context-compres…
rbiswasfc Jun 27, 2024
17a6f66
added 4 tasks from ruler
rbiswasfc Jun 27, 2024
cc8c039
ruff
rbiswasfc Jun 27, 2024
6a02bef
Merge pull request #28 from AnswerDotAI/rb/ruler
Jun 27, 2024
1b76bc3
Filter based on tokenized length
Jun 27, 2024
f7b5316
Add task stats and fix bug in task.py where self.mandatory_cols was b…
Jun 27, 2024
269439e
Update metric.
Jun 27, 2024
8014804
Minor bugfixes.
Jun 27, 2024
a5f3e71
Bugfix for LLM metrics.
Jun 28, 2024
0a44ff7
Add --tasks all option for bulk eval.
Jul 1, 2024
19594b5
Minor change to save path to include model name.
Jul 1, 2024
0ee7c30
Add eval_multi.py for running hparam sweep evals.
Jul 1, 2024
c2896e7
Switch from 4k to 8k RULER tasks.
Jul 1, 2024
267e9f5
Add 4k as a cache length.
Jul 1, 2024
c4c87b5
Add 8k.
Jul 1, 2024
0f719ca
Changed eval order for hyperparm
Jul 1, 2024
c608e80
Remove default attention thresholding.
Jul 1, 2024
127773a
Add random prompt compression strategy.
Jul 2, 2024
12a6435
Update default configs.
Jul 2, 2024
9c9845b
Changes to FastGen.
Jul 3, 2024
008175b
Add KVCacheAnalysis which computes attention loss.
Jul 3, 2024
12be67c
Update FastGen to use new attention loss calculation.
Jul 3, 2024
9056c97
add gist model generation utils to library
uSaiPrashanth Jul 4, 2024
7 changes: 6 additions & 1 deletion .gitignore
@@ -13,4 +13,9 @@ wandb

# downloaded by our tests
original_model.py
original_adapter.py
original_adapter.py

.vscode

torch_compile_debug
results
32 changes: 19 additions & 13 deletions GPTQ.py
@@ -15,7 +15,7 @@

from eval import (
setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
GPTFastEvalWrapper
GPTFastEvalWrapper,
)


@@ -63,7 +63,6 @@ def __init__(
)
self.pad_calibration_inputs = False


def add_input(self, args):
if self.inputs is None:
self.inputs = [MultiInput([arg]) for arg in args]
@@ -113,7 +112,6 @@ def _model_call(self, inps):
)



class MultiInput:
def __init__(self, inputs):
self.values = list(inputs)
@@ -126,7 +124,9 @@ def __getitem__(self, slice):
return MultiInput(self.values[slice])

def cuda(self):
self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
self.values = [
val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
]


class GenericGPTQRunner(fx.Interpreter):
@@ -235,7 +235,12 @@ def tensors_to_cuda(args):
)
transposed_args = list(
zip(
*[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args]
*[
x.values
if isinstance(x, MultiInput)
else [x] * multi_input_count
for x in flat_args
]
)
)
else:
@@ -244,8 +249,8 @@

# check whether we apply GPTQ to this module
quantize_linear = (
(target == aten.linear.default) # if its a linear
and id(args[1]) in self.id_to_name # and if we know the layer name
(target == aten.linear.default) # if its a linear
and id(args[1]) in self.id_to_name # and if we know the layer name
and not skip_quant # and if we weren't told to skip quantization
# and if the skip_layer_func doesn't say we should skip
and not (self.skip_layer_func is not None and self.skip_layer_func(args[1]))
@@ -259,9 +264,7 @@ def tensors_to_cuda(args):
inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)

if (
quantize_linear
): # calculate H instead of output (will run the linear eventually with updated weight)
if quantize_linear: # calculate H instead of output (will run the linear eventually with updated weight)
x = cur_args[0].float()
shape = x.shape
n = 1 if len(shape) == 2 else shape[0]
@@ -333,11 +336,14 @@ def SQNR(x, y):
target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True
)

print("SQNR for output without GPTQ (should be less than above)",
torch.cat([
print(
"SQNR for output without GPTQ (should be less than above)",
torch.cat(
[
SQNR(old.cpu(), old_q.cpu()).unsqueeze(0)
for (old, old_q) in zip(old_out.values, old_q_out.values)
]).mean(),
]
).mean(),
)
return new_out

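For orientation, the `SQNR` comparison printed in this hunk measures how close the quantized output is to the original (higher is better). The helper's body is not shown in the diff; a common textbook definition, assumed here purely for illustration, is:

```python
import torch


def sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB between a reference x and its
    # approximation y. Assumed definition for illustration only; the repository's
    # own SQNR helper is not visible in this diff.
    signal = x.float()
    noise = (x - y).float()
    return 20 * torch.log10(signal.norm() / noise.norm())
```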
204 changes: 13 additions & 191 deletions README.md
@@ -1,207 +1,29 @@
# gpt-fast
Simple and efficient pytorch-native transformer text generation.
# Fast-Compress

Featuring:
1. Very low latency
2. <1000 lines of python
3. No dependencies other than PyTorch and sentencepiece
4. int8/int4 quantization
5. Speculative decoding
6. Tensor parallelism
7. Supports Nvidia and AMD GPUs
**This is a WIP - do not use unless you are interested in contributing to the ongoing project.**

This is *NOT* intended to be a "framework" or "library" - it is intended to show off what kind of performance you can get with native PyTorch :) Please copy-paste and fork as you desire.
This repo extends [GPT-Fast](https://github.com/pytorch-labs/gpt-fast) by adding SOTA KV Cache compression methods.

For an in-depth walkthrough of what's in this codebase, see this [blog post](https://pytorch.org/blog/accelerating-generative-ai-2/).

## Examples
In the spirit of keeping the repo minimal, here are various examples of extensions you can make to gpt-fast as PRs.
- [Gemma support](https://github.com/pytorch-labs/gpt-fast/pull/115)
## Supported Models

### LLaMA family
Please check the rest of this page about benchmark of LLaMA family models.

### Mixtral 8x7B
We also support [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/), a high-quality sparse mixture-of-experts (MoE) model. The average token generation rates are:

| | 1 GPU | 2 GPU | 4 GPU | 8 GPU |
|------------------|---------|-----------|--------|------------|
|baseline(bfloat16)| OOM | 96.67 | 155.35 | 227.82 |
| int8 | 97.92 | 155.03 | 216.87 | 279.35 |

Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

For more details about Mixtral 8x7B, please check [this page](./mixtral-moe) or this [note](https://thonking.substack.com/p/short-supporting-mixtral-in-gpt-fast).

## Community

Projects inspired by gpt-fast in the community:

- [gpt-blazing](https://github.com/armed-gpt/gpt-blazing): applies the same performance optimization strategy to more models (e.g., baichuan2).
- [gptfast](https://github.com/MDK8888/GPTFast): applies a subset of the performance optimizations to all Huggingface models
- [gpt-accelera](https://github.com/Edward-Sun/gpt-accelera): extends `gpt-fast` to SFT/RM/PPO training and batched inference to optimize the throughput
When done, it *will* serve as an open-source, hackable toolkit to accelerate research into memory-efficient inference.

## Installation
[Download PyTorch nightly](https://pytorch.org/get-started/locally/)
Install sentencepiece and huggingface_hub
```bash
pip install sentencepiece huggingface_hub
pip install packaging ninja
MAX_JOBS=8 pip install flash-attn --no-build-isolation # Set MAX_JOBS to a lower value if you get OOM errors.
pip install -r requirements.txt
```

To download Llama models, go to https://huggingface.co/meta-llama/Llama-2-7b and go through the steps to obtain access.
Then log in with `huggingface-cli login`

After logging in with `huggingface-cli login`, run


## Downloading Weights
Models tested/supported
```text
tinyllamas/stories{15,42,100}
openlm-research/open_llama_7b
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-70b-chat-hf
codellama/CodeLlama-7b-Python-hf
codellama/CodeLlama-34b-Python-hf
mistralai/Mistral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
```

For example, to convert Llama-2-7b-chat-hf
```bash
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
./scripts/prepare.sh $MODEL_REPO
bash scripts/prepare_llama3.sh
```

## Benchmarks
Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).

| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
| -------- | ------- | ------ | ------ |
| Llama-2-7B | Base | 104.9 | 1397.31 |
| | 8-bit | 155.58 | 1069.20 |
| | 4-bit (G=32) | 196.80 | 862.69 |
| Llama-2-70B | Base | OOM ||
| | 8-bit | 19.13 | 1322.58 |
| | 4-bit (G=32) | 25.25 | 1097.66 |

### Speculative Sampling
[Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s
This will create the necessary model and tokenizer files for `Meta-Llama-3-8B-Instruct` within `./checkpoints`. It will also create a smaller model for debugging purposes only, called `Meta-Llama-3-8B-Instruct-4-Layers`, which removes all layers except for the first 4. It is quicker to load but will generate nonsense, so use it only for debugging.
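For orientation, a layer-truncated debug checkpoint of this kind could be produced roughly as in the sketch below. This is a hypothetical illustration using the Hugging Face `transformers` API, not necessarily what `scripts/prepare_llama3.sh` actually does; the source and output paths are assumptions.

```python
# Hypothetical sketch: keep only the first 4 decoder layers to get a fast-loading,
# nonsense-generating debug model. Not the repo's actual prepare_llama3.sh logic.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

src = "meta-llama/Meta-Llama-3-8B-Instruct"
dst = "./checkpoints/Meta-Llama-3-8B-Instruct-4-Layers"

model = AutoModelForCausalLM.from_pretrained(src, torch_dtype=torch.bfloat16)
model.model.layers = model.model.layers[:4]  # drop all but the first 4 transformer blocks
model.config.num_hidden_layers = 4           # keep the config consistent with the weights

model.save_pretrained(dst)
AutoTokenizer.from_pretrained(src).save_pretrained(dst)
```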

### Tensor Parallelism
| Model | Number of GPUs | Tokens/Second | Memory Bandwidth (GB/s) |
| -------- | ------- | ------ | ------ |
| Llama-2-7B | 1 | 104.9 | 1397.31 |
| | 2 | 168.84 | 1181.99 |
| | 4 | 254.02 | 955.83 |
| | 8 | 328.43 | 704.10 |
| Llama-2-70B | 1 | OOM | |
| | 2 | 21.32 | 1481.87 |
| | 4 | 38.01 | 1340.76 |
| | 8 | 62.50 | 1135.29 |
## Usage

### Tensor Parallelism + Quantization
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
| -------- | ------- | ------ | ------ |
| Llama-2-70B | Base | 62.50 | 1135.29 |
| | 8-bit | 80.44 | 752.04 |
| | 4-bit (G=32) | 90.77 | 548.10 |

### AMD
Benchmarks run on one GCD of a MI-250x.

| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
| -------- | ------- | ------ | ------ |
| Llama-2-7B | Base | 76.33 | 1028.70 |
| | 8-bit | 101.86 | 700.06 |

## Generate Text

Model definition in `model.py`, generation code in `generate.py`.

```bash
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --prompt "Hello, my name is"
```

To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though.

## Quantization
Choose device to use by
```bash
# The current support devices: cuda, cpu
export DEVICE=cuda
```
### Int8 Weight-Only Quantization
To generate this version of the model
```bash
# Spits out model at checkpoints/$MODEL_REPO/model_int8.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
```
To run with int8, just pass the int8 checkpoint to generate.py.
```bash
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
```

### Int4 Weight-Only Quantization
To generate int4 version of model
```bash
# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
```

To run with int4, just pass the int4 checkpoint to generate.py.
```bash
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
```

## Speculative Sampling
To generate with speculative sampling, `DRAFT_MODEL_REPO` should point to a smaller model than `MODEL_REPO`.

In this example, the "smaller" model is just the int8 quantized version of the model.
```
export DRAFT_MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --draft_checkpoint_path checkpoints/$DRAFT_MODEL_REPO/model_int8.pth
```

Note: Run on an A100 80GB, albeit power-limited to 330 watts. Empirically, peak bandwidth appears to be about 1700 GB/s.


## Tensor Parallelism
```bash
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
```

## Experimental
### Evaluation
We use the EleutherAI evaluation harness to evaluate our model accuracy. To evaluate the accuracy, make sure the evaluation harness is installed and pass your model checkpoint and desired tasks to eval.py.

```bash
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande
```

Note: Generative tasks are currently not supported for gpt-fast

Installation Instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install

### GPTQ
We have a pure PyTorch implementation of GPTQ that utilizes torch._dynamo.export to access the model structure. You can generate a GPTQ-quantized int4 version of the model by using the same quantization command but adding 'gptq' to the quantization mode, i.e.
```bash
# Spits out model at checkpoints/$MODEL_REPO/model_int4-gptq.g32.pth
python quantize.py --mode int4-gptq --calibration_tasks wikitext --calibration_seq_length 2048
```

You can then eval or generate text with this model in the same way as above.

## License

`gpt-fast` is released under the [BSD 3](https://github.com/pytorch-labs/gpt-fast/main/LICENSE) license.

## Acknowledgements
Thanks to:
* Lightning AI for supporting pytorch and work in flash attention, int8 quantization, and LoRA fine-tuning.
* GGML for driving forward fast, on device inference of LLMs
* Karpathy for spearheading simple, interpretable and fast LLM implementations
* MLC-LLM for pushing 4-bit quantization performance on heterogeneous hardware
python generate.py --compile --cache_strategy full --prompt "short_prompt_long_output.txt"
```
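Here `--cache_strategy full` presumably keeps the entire KV cache uncompressed and serves as the baseline. Judging by the commit history above, alternative compression strategies (fixed-window, ScissorHands, SnapKV, L2-norm, random, FastGen) are implemented as `KVCache*` classes and are likely selected via the same flag, though their exact option names are not shown in this diff.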
45 changes: 45 additions & 0 deletions attention_utils.py
@@ -0,0 +1,45 @@
import math
from typing import Tuple

import torch
from torch.nn import functional as F


def scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
scale=None,
return_attn=False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor | None]:
"""
Uses naive PyTorch sdpa implementation if we need to return_attn. Otherwise use the optimized version.

The naive implementation will be optimized later.
"""
if not return_attn:
return F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
scale=scale,
), None
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_weight = query @ key.transpose(-2, -1) * scale_factor

if attn_mask is not None:
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_weight += attn_bias

# TODO if returning attn_weight, should we just modify the attn_weight tensor to be attn_prob?
attn_prob = torch.softmax(attn_weight, dim=-1)
attn_prob = torch.dropout(attn_prob, dropout_p, train=True)
return_logits = kwargs.get("return_attn_logits", False)
return attn_prob @ value, attn_weight if return_logits else attn_prob
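For context, a minimal usage sketch of this helper is shown below. The tensor shapes are hypothetical and follow the standard `(batch, heads, sequence, head_dim)` convention of `F.scaled_dot_product_attention`.

```python
import torch

from attention_utils import scaled_dot_product_attention

# Hypothetical shapes: batch=2, heads=4, query length=8, key/value length=16, head dim=64.
q = torch.randn(2, 4, 8, 64)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)

# Fast path: defers to the fused PyTorch kernel and returns None for the attention map.
out, attn = scaled_dot_product_attention(q, k, v)
assert attn is None

# Naive path: also returns the (B, H, L, S) attention probabilities, which the
# KV-cache eviction strategies above can use for scoring.
out, attn = scaled_dot_product_attention(q, k, v, return_attn=True)
print(out.shape, attn.shape)  # torch.Size([2, 4, 8, 64]) torch.Size([2, 4, 8, 16])
```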