Commit fda6fb5
[GenAI] Support Token Eviction (#1010)
* [GenAI] Support Token Eviction
* Update score aggregation
1 parent 07ee427 commit fda6fb5

File tree

8 files changed: +481 additions, −42 deletions

modules/genai_optimizations/README.md

Lines changed: 18 additions & 0 deletions
@@ -17,6 +17,24 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
 - **Tri-Shape Mode** – A static block-sparse attention pattern that preserves the initial tokens, local windows, and the final segment of the query, forming a triangular structure to capture critical tokens while maintaining instruction-following performance in both turn-0 and multi-request scenarios. Paper: https://arxiv.org/pdf/2412.10319
 - **XAttention Mode** – A dynamic block-sparse attention mechanism that accelerates inference by focusing computation on the most important regions of the attention matrix using antidiagonal block scoring, reducing FLOPs and memory usage without significant loss of accuracy. Paper: https://arxiv.org/pdf/2503.16428
 
+- [**KV Cache Token Eviction**](./token_eviction.py):
+  Designed to optimize KV cache memory usage during autoregressive generation in LLMs. It selectively removes less important cached tokens while preserving those crucial for contextual understanding, enabling efficient long-sequence inference under constrained memory. Note that currently eviction starts only after the full prompt has been processed; i.e., no eviction takes place during the prefill phase.
+
+  The KV cache is split into three parts: **start**, **intermediate (evictable)**, and **recent**. The size of each part is configurable:
+  - **Start Area** – Initial tokens that are never evicted.
+  - **Intermediate Area** – Tokens that can be evicted based on importance scores.
+  - **Recent Area** – Most recent tokens that are preserved (not evicted while in this area, but they naturally migrate toward the evictable area as text generation continues).
+
+  Eviction granularity can be **per-token** or **per-group**:
+  - **Per-token** – Tokens are evicted independently from the KV cache.
+  - **Per-group** – Only fully filled blocks from the evictable area are removed. Tokens are managed in consecutive, non-overlapping groups, following the concept of *Paged Attention*, which organizes the KV cache into pages. Each token belongs to a single page and remains there for the entire generation process. To maximize eviction efficiency, entire pages are evicted rather than individual tokens. The `group_size` is a configurable algorithm parameter.
+
+  Supported modes:
+  - **H2O Mode** – Evicts tokens using the *Heavy-Hitter Oracle* strategy, which accumulates attention scores to identify and retain high-impact tokens. It also preserves recent tokens due to their strong correlation with the current context. Scores are accumulated throughout the entire generation process, and their weighting can be adjusted via the `normalize_scores` parameter, which controls whether attention scores are normalized by the number of times each token was attended to.
+    Paper: https://arxiv.org/pdf/2306.14048
+  - **SnapKV Mode** – Modifies the *H2O* approach by computing token importance within a small sliding window of the most recent queries during the prefill stage, then reverting to the H2O strategy during decoding. The authors observed that only a small subset of prompt tokens is sufficient for accurate response generation.
+    Paper: https://arxiv.org/pdf/2404.14469
+
 ## Supported and tested models
 
 Large Language Models:
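
The README text added above maps directly onto the eviction API this commit introduces. A minimal, illustrative sketch of configuring the three cache areas and per-group eviction with the classes exported from `genai_opt` (the values are examples, not recommended defaults):

```python
# Illustrative configuration of KV cache token eviction; values are examples only.
# The classes are those this commit exports from genai_opt (see genai_opt/__init__.py below).
from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor

params = KVCacheCompressionParameters(
    algorithm=KVCacheCompressionMode("h2o"),  # or "snapkv"
    granularity="per_group",      # evict whole groups (pages); "per_token" evicts tokens individually
    group_size=32,                # page size used by per-group eviction
    start_tokens=32,              # start area: initial tokens that are never evicted
    intermediate_tokens=1024,     # evictable area: tokens dropped based on importance scores
    recent_tokens=128,            # recent area: newest tokens, kept until they age into the evictable area
    normalize_scores=True,        # H2O option: normalize accumulated scores by attention counts
    window_size=8,                # SnapKV option: importance-score aggregation window over recent queries
)
compressor = KVCacheCompressor(eviction_parameters=params)
# The compressor is applied as a context manager around generation,
# as the benchmark scripts in this commit do via contextlib.ExitStack.
```

With this configuration, eviction keeps the first 32 tokens and the 128 most recent ones, and removes whole 32-token groups from the 1024-token evictable area according to their importance scores.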

modules/genai_optimizations/benchmarks/README.md

Lines changed: 19 additions & 3 deletions
@@ -17,12 +17,18 @@ python longbench.py \
     --subset samsum \
     --model meta-llama/Llama-3.2-1B-Instruct \
     --use_custom_attention \
-    --prefill_impl tri-shape
+    --prefill_impl tri-shape \
+    --enable_eviction \
+    --algorithm h2o \
+    --granularity per_group \
+    --normalize_scores \
+    --intermediate_tokens 1024
 ```
 This will automatically:
 
 - Download the selected model and dataset
 - Apply sparse attention computation during the prefill stage
+- Apply token eviction during the decoding stage
 - Evaluate the model and report the score
 
 </details>
@@ -46,13 +52,18 @@ python mmebench.py \
     --num_keep_tokens 128 \
     --theta 0.5 \
     --use_custom_attention \
-    --prefill_impl x-attention
+    --prefill_impl x-attention \
+    --enable_eviction \
+    --algorithm snapkv \
+    --granularity per_group \
+    --window_size 8
 ```
 This will automatically:
 
 - Download the selected model and dataset
 - Apply the visual token pruning algorithm
 - Apply sparse attention computation during the prefill stage
+- Apply token eviction during the decoding stage
 - Evaluate the model and report the score
 
 </details>
@@ -73,14 +84,19 @@ python milebench.py \
     --num_keep_tokens 64 \
     --theta 0.5 \
     --use_custom_attention \
-    --prefill_impl tri-shape
+    --prefill_impl tri-shape \
+    --enable_eviction \
+    --algorithm snapkv \
+    --granularity per_group \
+    --window_size 8
 ```
 
 This will automatically:
 
 - Download the selected model and dataset
 - Apply the visual token pruning algorithm
 - Apply sparse attention computation during the prefill stage
+- Apply token eviction during the decoding stage
 - Evaluate the model and report the score
 
 </details>
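
For reference, the eviction flags shown in these examples correspond to the argparse arguments and helpers added in `benchmarks/utils.py` later in this commit. A hedged sketch of the `longbench.py` configuration built programmatically (the listed fields mirror the CLI flags; unspecified ones use the argparse defaults from this commit):

```python
# Sketch only: real runs go through add_token_eviction_args() in benchmarks/utils.py.
from argparse import Namespace

from utils import get_eviction_patcher  # helper added in benchmarks/utils.py by this commit

eviction_args = Namespace(
    enable_eviction=True,
    algorithm="h2o",
    granularity="per_group",
    normalize_scores=True,
    intermediate_tokens=1024,
    start_tokens=32,      # argparse default
    recent_tokens=128,    # argparse default
    group_size=32,        # argparse default
    window_size=None,     # argparse default
)

token_eviction = get_eviction_patcher(eviction_args)  # returns a KVCacheCompressor patcher
```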

modules/genai_optimizations/benchmarks/longbench.py

Lines changed: 14 additions & 12 deletions
@@ -19,8 +19,8 @@
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 
-from genai_opt import SparseAttention
-from utils import add_attention_args
+from utils import add_attention_args, add_token_eviction_args
+from utils import get_eviction_patcher, get_sparse_attention_patcher
 
 # (Phi3 and DeepSeek issue)
 # AttributeError: 'DynamicCache' object has no attribute 'get_max_length'. Did you mean: 'get_seq_length'?
@@ -315,28 +315,29 @@ def evaluate(args):
         args.model, trust_remote_code=True, token=os.environ.get("HF_TOKEN", None)
     )
 
+    kwargs = {"temperature": None, "top_p": None, "top_k": None}
+    # force attn_implementation="eager" when using token eviction without custom attention
+    if args.enable_eviction and not args.use_custom_attention:
+        kwargs["attn_implementation"] = "eager"
+
     model = AutoModelForCausalLM.from_pretrained(
         args.model,
-        # attn_implementation="eager",
         trust_remote_code=True,
         dtype=torch.float16,
         device_map="auto",
         token=os.environ.get("HF_TOKEN", None),
-        temperature=None,
-        top_p=None,
-        top_k=None,
+        **kwargs,
    ).eval()
 
     patchers = []
     if args.use_custom_attention:
-        sparse_attn = SparseAttention(
-            algorithm=args.prefill_impl,
-            threshold=args.threshold,
-            recent_size=args.recent_size,
-            last_query_size=args.last_query_size,
-        )
+        sparse_attn = get_sparse_attention_patcher(args)
         patchers.append(sparse_attn)
 
+    if args.enable_eviction:
+        token_eviction = get_eviction_patcher(args)
+        patchers.append(token_eviction)
+
     max_new_tokens = dataset.get_max_new_tokens()
     answers = []
     max_length = 4500
@@ -391,6 +392,7 @@ def evaluate(args):
     parser.add_argument("--model", type=str, required=True, help="Model name")
 
     add_attention_args(parser)
+    add_token_eviction_args(parser)
     args = parser.parse_args()
 
     evaluate(args)

modules/genai_optimizations/benchmarks/milebench.py

Lines changed: 17 additions & 13 deletions
@@ -20,8 +20,9 @@
 from transformers import AutoProcessor
 
 from logging import getLogger
-from genai_opt import get_inputs_embeds, SparseAttention
-from utils import add_attention_args, add_visual_pruning_args
+from genai_opt import get_inputs_embeds
+from utils import add_attention_args, add_visual_pruning_args, add_token_eviction_args
+from utils import get_eviction_patcher, get_sparse_attention_patcher
 
 
 logger = getLogger(__name__)
@@ -454,21 +455,25 @@ def get_model_class(model_name):
 
     add_visual_pruning_args(parser)
     add_attention_args(parser)
+    add_token_eviction_args(parser)
     args = parser.parse_args()
 
     dataset = MileBenchDataset(data_dir=args.data_dir, subset=args.subset)
     processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
     model_cls = get_model_class(args.model)
+
+    kwargs = {"temperature": None, "top_p": None, "top_k": None}
+    # force attn_implementation="eager" when using token eviction without custom attention
+    if args.enable_eviction and not args.use_custom_attention:
+        kwargs["attn_implementation"] = "eager"
+
     model = model_cls.from_pretrained(
         args.model,
-        # attn_implementation="eager",
         trust_remote_code=True,
         dtype=torch.bfloat16,
         device_map="auto",
         token=os.environ.get("HF_TOKEN", None),
-        temperature=None,
-        top_p=None,
-        top_k=None,
+        **kwargs
     )
     model = model.eval()
 
@@ -482,12 +487,11 @@ def get_model_class(model_name):
 
     contexts = []
     if args.use_custom_attention:
-        sparse_prefill = SparseAttention(
-            algorithm=args.prefill_impl,
-            threshold=args.threshold,
-            recent_size=args.recent_size,
-            last_query_size=args.last_query_size,
-        )
-        contexts.append(sparse_prefill)
+        sparse_attn = get_sparse_attention_patcher(args)
+        contexts.append(sparse_attn)
+
+    if args.enable_eviction:
+        token_eviction = get_eviction_patcher(args)
+        contexts.append(token_eviction)
 
     evaluate(dataset, processor, model, num_keep_tokens=num_keep_tokens, theta=theta, contexts=contexts)

modules/genai_optimizations/benchmarks/mmebench.py

Lines changed: 16 additions & 13 deletions
@@ -17,8 +17,9 @@
 from transformers import AutoProcessor
 from transformers import set_seed
 
-from genai_opt import get_inputs_embeds, SparseAttention
-from utils import add_attention_args, add_visual_pruning_args
+from genai_opt import get_inputs_embeds
+from utils import add_attention_args, add_visual_pruning_args, add_token_eviction_args
+from utils import get_eviction_patcher, get_sparse_attention_patcher
 
 
 class MetricCalculator:
@@ -89,16 +90,19 @@ def evaluate(args):
 
     processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
     model_cls = get_model_class(model_name)
+
+    kwargs = {"temperature": None, "top_p": None, "top_k": None}
+    # force attn_implementation="eager" when using token eviction without custom attention
+    if args.enable_eviction and not args.use_custom_attention:
+        kwargs["attn_implementation"] = "eager"
+
     model = model_cls.from_pretrained(
         model_name,
         trust_remote_code=True,
-        # attn_implementation="eager",
         dtype=torch.bfloat16,
         device_map="auto",
         token=os.environ.get("HF_TOKEN", None),
-        temperature=None,
-        top_p=None,
-        top_k=None,
+        **kwargs
     ).eval()
 
     if args.enable_visual_pruning:
@@ -111,15 +115,13 @@ def evaluate(args):
 
     contexts = []
     if args.use_custom_attention:
-        print(f"Enable custom attention kernel with {args.prefill_impl} implementation")
-        sparse_prefill = SparseAttention(
-            algorithm=args.prefill_impl,
-            threshold=args.threshold,
-            recent_size=args.recent_size,
-            last_query_size=args.last_query_size,
-        )
+        sparse_prefill = get_sparse_attention_patcher(args)
         contexts.append(sparse_prefill)
 
+    if args.enable_eviction:
+        token_eviction = get_eviction_patcher(args)
+        contexts.append(token_eviction)
+
     all_items = []
     with ExitStack() as stack:
         for ctx in contexts:
@@ -230,6 +232,7 @@ def get_model_class(model_name):
 
     add_visual_pruning_args(parser)
     add_attention_args(parser)
+    add_token_eviction_args(parser)
     args = parser.parse_args()
 
     evaluate(args)

modules/genai_optimizations/benchmarks/utils.py

Lines changed: 52 additions & 1 deletion
@@ -1,6 +1,7 @@
 # Copyright (C) 2018-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
-
+from genai_opt import SparseAttention
+from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor
 
 def add_visual_pruning_args(parser):
     group = parser.add_argument_group("Visual Token Pruning Arguments")
@@ -28,3 +29,53 @@ def add_attention_args(parser):
         help="Window size of recent tokens each query can attend to in the Tri-shape pattern"
     )
     return parser
+
+
+def add_token_eviction_args(parser):
+    group = parser.add_argument_group("Token Eviction Arguments")
+    group.add_argument("--enable_eviction", action="store_true", help="Enable token eviction")
+    group.add_argument("--algorithm", default="snapkv", choices=["snapkv", "h2o"], help="The KV cache eviction algorithm")
+    group.add_argument("--granularity", default="per_group", choices=["per_token", "per_group"], help="Eviction granularity")
+    group.add_argument(
+        "--normalize_scores",
+        action="store_true",
+        help="Whether to normalize the attention scores by the number of times each token was attended to."
+    )
+    group.add_argument(
+        "--start_tokens",
+        type=int,
+        default=32,
+        help="The number of tokens in the beginning of the cache (least recent) to be retained"
+    )
+    group.add_argument("--intermediate_tokens", type=int, default=1024, help="The number of intermediate tokens to consider for eviction")
+    group.add_argument("--recent_tokens", type=int, default=128, help="The number of most recent tokens to be retained")
+    group.add_argument("--group_size", type=int, default=32, help="Group size for per-group eviction strategy")
+    group.add_argument("--window_size", type=int, default=None, help="The size of the importance score aggregation window")
+    return parser
+
+
+def get_sparse_attention_patcher(args):
+    print(f"Enable custom attention kernel with {args.prefill_impl} implementation")
+    return SparseAttention(
+        algorithm=args.prefill_impl,
+        threshold=args.threshold,
+        recent_size=args.recent_size,
+        last_query_size=args.last_query_size,
+        output_attentions=args.enable_eviction,  # output attention weights only if eviction is enabled
+    )
+
+
+def get_eviction_patcher(args):
+    print(f"Enable token eviction with {args.algorithm} algorithm")
+    algorithm = KVCacheCompressionMode(args.algorithm)
+    params = KVCacheCompressionParameters(
+        algorithm=algorithm,
+        granularity=args.granularity,
+        group_size=args.group_size,
+        start_tokens=args.start_tokens,
+        recent_tokens=args.recent_tokens,
+        intermediate_tokens=args.intermediate_tokens,
+        normalize_scores=args.normalize_scores,
+        window_size=args.window_size,
+    )
+    return KVCacheCompressor(eviction_parameters=params)
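
Taken together, these helpers let a benchmark script stack sparse prefill attention and token eviction as context managers around generation. A minimal sketch of that flow, assuming the helpers above are importable from `benchmarks/utils.py` and that the returned patchers are context managers (as the `ExitStack` usage in `mmebench.py` suggests); the model loading, prompt, and `generate` call are illustrative, not the benchmark code itself:

```python
# Minimal sketch of applying the patchers returned by the helpers above.
# The prompt and generate() arguments are illustrative assumptions.
from contextlib import ExitStack

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_eviction_patcher, get_sparse_attention_patcher


def run(args, prompt: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    kwargs = {}
    # Token eviction needs attention weights, so the benchmark scripts fall back to
    # eager attention when the custom sparse-attention kernel is not used.
    if args.enable_eviction and not args.use_custom_attention:
        kwargs["attn_implementation"] = "eager"

    model = AutoModelForCausalLM.from_pretrained(
        args.model, dtype=torch.float16, device_map="auto", **kwargs
    ).eval()

    patchers = []
    if args.use_custom_attention:
        patchers.append(get_sparse_attention_patcher(args))
    if args.enable_eviction:
        patchers.append(get_eviction_patcher(args))

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with ExitStack() as stack:
        for ctx in patchers:
            stack.enter_context(ctx)  # each patcher is assumed to be a context manager
        output = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output[0], skip_special_tokens=True)
```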

modules/genai_optimizations/genai_opt/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 
 from genai_opt.visual_token_pruning import get_inputs_embeds
 from genai_opt.sparse_attention import SparseAttention
+from genai_opt.token_eviction import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor
