Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
exclude = .venv,venv,.git,__pycache__,build,dist, .mypy_cache
exclude = .venv,venv,.git,__pycache__,build,dist,.mypy_cache,.pytest_cache
max-line-length = 120
per-file-ignores =
__init__.py:F401
Expand Down
13 changes: 7 additions & 6 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: 3.10.11
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build

- name: Install uv
uses: astral-sh/setup-uv@v6

- name: Build
run: uv build --no-sources

- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1
with:
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ jobs:
with:
python-version: 3.10.11

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH # Add Poetry to the PATH
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true

- name: Install dependencies
run: poetry install --with dev
run: uv sync --all-groups

- name: Run style checks
run: make style
12 changes: 6 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Test

on:
push:
branches: [ main ]
branches: [ main ]
pull_request:

jobs:
Expand All @@ -15,12 +15,12 @@ jobs:
with:
python-version: 3.10.11

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
echo "$HOME/.local/bin" >> $GITHUB_PATH # Add Poetry to the PATH
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true

- name: Install dependencies
run: poetry install --with dev
run: uv sync --all-groups

- run: make test
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ dev_notebooks/
results/
reports/
.DS_Store
poetry.lock
uv.lock
*.parquet
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
12 changes: 6 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
SHELL := /bin/bash
POETRY ?= $(shell which poetry)
UV ?= $(shell which uv)
BUILD_VERSION:=$(APP_VERSION)
TESTS_FILTER:=

PYTEST_LOG=--log-cli-level=debug --log-format="%(asctime)s %(levelname)s [%(name)s:%(filename)s:%(lineno)d] %(message)s" --log-date-format="%Y-%m-%d %H:%M:%S"

.PHONY: isort
isort:
$(POETRY) run isort .
$(UV) run isort .

.PHONY: black
black:
$(POETRY) run black .
$(UV) run black .

# Convenience target: formats the code base by running the isort and black targets.
# Declared with .PHONY (leading dot) so a file named "format" can never shadow it.
.PHONY: format
format: isort black
Expand All @@ -24,10 +24,10 @@ style: reports
@echo -n > reports/copyright_errors.log
@echo

-$(POETRY) run flake8 | tee -a reports/flake8_errors.log
-$(UV) run flake8 | tee -a reports/flake8_errors.log
@if [ -s reports/flake8_errors.log ]; then exit 1; fi

-$(POETRY) run mypy . --check-untyped-defs | tee -a reports/mypy.log
-$(UV) run mypy . --check-untyped-defs | tee -a reports/mypy.log
@if ! grep -Eq "Success: no issues found in [0-9]+ source files" reports/mypy.log ; then exit 1; fi

@echo "Checking for SPDX-FileCopyrightText headers in Python files..."
Expand All @@ -42,7 +42,7 @@ reports:
.PHONY: test
test: reports
PYTHONPATH=. \
$(POETRY) run pytest \
$(UV) run pytest \
--cov-report xml:reports/coverage.xml \
--cov=kvpress/ \
--junitxml=./reports/junit.xml \
Expand Down
34 changes: 25 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,17 @@ Deploying long-context LLMs is costly due to the linear growth of the key-value
pip install kvpress
```

If possible, install flash attention:
```bash
pip install flash-attn --no-build-isolation
```

For a local installation with all dev dependencies, use poetry:
For a local installation with all dev dependencies, use uv:

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
poetry install --with dev
uv sync --all-groups
```

## Usage

kvpress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you:
KVPress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you:

```python
from transformers import pipeline
Expand Down Expand Up @@ -208,4 +203,25 @@ with press(model):

However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not allow generating answers for multiple questions at once.

</details>
</details>


## Advanced installation settings
To install optional packages, you can use [uv](https://docs.astral.sh/uv/).
To install with flash attention, just run:

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra flash-attn
```

To install with dependencies for evaluation, run

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra eval
```

Note that optional dependencies can be combined, e.g. `uv sync --extra eval --extra flash-attn`.
1 change: 1 addition & 0 deletions evaluation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
We support evaluation for all the presses implemented in the library, on a variety of popular benchmarks.

### Quick Start 🚀
> Evaluation requires some additional packages. You can install them with `uv sync --extra eval`

Running evaluation is straightforward! Make sure you are in the `evaluation` directory, then:

Expand Down
5 changes: 4 additions & 1 deletion kvpress/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def preprocess(
else:
separator = "\n" + "#" * len(context)
context = self.tokenizer.apply_chat_template(
[{"role": "user", "content": context + separator}], add_generation_prompt=True, tokenize=False
[{"role": "user", "content": context + separator}],
add_generation_prompt=True,
tokenize=False,
enable_thinking=False,
)
context, question_suffix = context.split(separator)

Expand Down
7 changes: 4 additions & 3 deletions kvpress/presses/block_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ class BlockPress(BasePress):
BlockPress: Block-wise iterative KV cache compression.

Applies compression in fixed-size blocks. Iteratively scores and prunes tokens block by block, maintaining
a buffer of previously kept tokens for context. Mathematically equivalent
to global compression when scoring uses only local information.
a buffer of previously kept tokens for context. Mathematically equivalent to global compression when
scoring uses only local information. It was introduced in the KeyDiff paper as part of the KeyDiff press,
but it can also work as a standalone press.

Based on BlockPress (https://arxiv.org/abs/2504.15364).
Based on the KeyDiff paper (https://arxiv.org/abs/2504.15364).

Parameters
----------
Expand Down
6 changes: 6 additions & 0 deletions kvpress/presses/keydiff_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class KeyDiffPress(ScorerPress):

Based on KeyDiff (https://arxiv.org/abs/2504.15364).

Note: The original press in the KeyDiff paper implements a block-wise iterative compression.
In KVPress, the iterative compression is implemented in the BlockPress class.
Therefore, to replicate the paper's implementation, please use:

`press = BlockPress(press=KeyDiffPress(compression_ratio=0.x), block_size=N)`

Parameters
----------
compression_ratio : float, default=0.0
Expand Down
5 changes: 4 additions & 1 deletion kvpress/presses/kvzip_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def __call__(self, model: PreTrainedModel) -> Generator:
dummy_context = "dummy context"
separator = "\n" + "#" * len(dummy_context)
temp_context = tokenizer.apply_chat_template(
[{"role": "user", "content": dummy_context + separator}], add_generation_prompt=True, tokenize=False
[{"role": "user", "content": dummy_context + separator}],
add_generation_prompt=True,
tokenize=False,
enable_thinking=False,
)
context, suffix_text = temp_context.split(separator)
prefix_text = context.split(dummy_context)[0]
Expand Down
96 changes: 56 additions & 40 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,50 +1,66 @@
[tool.poetry]
[project]
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Alessio Devoto", "Jiwei Liu", "David Austin"]
version = "0.2.10"
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.9"
authors = [
{ name = "Simon Jegou" },
{ name = "Maximilian Jeblick" },
{ name = "Alessio Devoto" },
{ name = "Jiwei Liu" },
{ name = "David Austin" },
]
requires-python = ">=3.10"
readme = "README.md"
dependencies = [
"numpy>=2.0.0,<3",
"torch>=2.3.1,<3",
"transformers>=4.48.0, <4.54.0",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
"datasets>=2.21.0,<3",
"pandas>=2.2.2,<3",
"accelerate>=1.0.0,<2",
"requests>=2.32.3,<3",
"cachetools>=5.5.2,<6",
]

[tool.poetry.dependencies]
python = ">=3.10"
numpy = "^2.0.0"
torch = "^2.3.1"
transformers = ">=4.48.0, <4.54.0"
sentencepiece = "^0.2.0"
protobuf = "^5.27.2"
datasets = "^2.21.0"
pandas = "^2.2.2"
accelerate = "^1.0.0"
requests = "^2.32.3"
cachetools = "^5.5.2"
[project.optional-dependencies]
eval = [
"rouge>=1.0.1,<2",
"nltk>=3.9.1,<4",
"tqdm>=4.66.4,<5",
"scipy>=1.13.1,<2",
"fire>=0.6.0,<0.7",
"bert-score>=0.3.13,<0.4",
]
flash-attn = [
"flash-attn"
]

[tool.poetry.group.dev]
optional = true
[dependency-groups]
dev = [
"pytest>=7.0.0,<8",
"flake8>=7.0.0,<8",
"isort>=5.13.2,<6",
"black>=24.8.0,<25",
"mypy>=1.13.0,<2",
"pytest-cov>=5.0.0,<6",
"pytest-dependency>=0.6.0,<0.7",
"pytest-html>=4.1.1, <5.0.0",
"types-pyyaml~=6.0",
"ipykernel>=6.29.4,<7",
"bs4>=0.0.2,<0.0.3",
"nvitop>=1.3.2,<2",
"matplotlib>=3.9.0,<4",
]

[tool.poetry.group.dev.dependencies]
pytest = "^7.0.0"
flake8 = "^7.0.0"
isort = "^5.13.2"
black = "^24.8.0"
mypy = "^1.13.0"
pytest-cov = "^5.0.0"
pytest-dependency = "^0.6.0"
pytest-html = ">=4.1.1, <5.0.0"
types-pyyaml = "^6.0"
ipykernel = "^6.29.4"
bs4 = "^0.0.2"
nvitop = "^1.3.2"
bert-score = "^0.3.13"
rouge = "^1.0.1"
nltk = "^3.9.1"
tqdm = "^4.66.4"
scipy = "^1.13.1"
matplotlib = "^3.9.0"
fire = "^0.6.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
no-build-isolation-package = ["flash-attn"]

[tool.black]
line-length = 120
Expand All @@ -64,7 +80,7 @@ skip = ["venv", ".venv"]
ignore_missing_imports = true
allow_redefinition = true
strict_optional = false
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles)"
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles|.pytest_cache|reports)"
disable_error_code = ["union-attr", "operator", "call-overload", "arg-type"]

[[tool.mypy.overrides]]
Expand Down