
Commit 07e8148

[GenAI] Support Token Eviction for LRMs (#1012)
* [GenAI] Support refined algorithms (KVCrush, DiverseKV)
* [GenAI] Support Token Eviction for LRMs
* minor fixes
1 parent ad5bb54 commit 07e8148

9 files changed: +1630 −50 lines changed


modules/genai_optimizations/README.md

Lines changed: 15 additions & 0 deletions
@@ -6,6 +6,7 @@ This module provides experimental optimizations for GenAI models in PyTorch. The

- Text Generation Using LLMs
- Visual language text generation
- Reasoning and Problem Solving

## Supported Generative AI Optimization Methods

@@ -34,6 +35,14 @@ This module provides experimental optimizations for GenAI models in PyTorch. The

  Paper: https://arxiv.org/pdf/2306.14048
- **SnapKV Mode** – Modifies the *H2O* approach by computing token importance within a small sliding window of the most recent queries during the prefill stage, then reverting to the H2O strategy during decoding. The authors observed that only a small subset of prompt tokens is sufficient for accurate response generation.
  Paper: https://arxiv.org/pdf/2404.14469
- **RKV Mode** – Computes token importance scores based on attention weights over a sliding window of the most recent queries during both the prefill and decode stages. Importance scores are stabilized using per-token max-pooling and then averaged across attention heads.
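As a rough, self-contained illustration of this kind of window-based scoring (not the module's actual implementation), the sketch below accumulates the attention each cached token receives from the most recent queries, smooths the result with max-pooling along the token axis, and averages over heads. The tensor layout, window size, and pooling kernel are assumptions.

```python
import torch
import torch.nn.functional as F


def window_importance(attn_weights: torch.Tensor, window: int = 32, kernel: int = 7) -> torch.Tensor:
    """attn_weights: post-softmax attention for one layer, shaped [num_heads, q_len, kv_len]."""
    # Attention mass each cached token receives from the most recent `window` queries.
    scores = attn_weights[:, -window:, :].sum(dim=1)  # [num_heads, kv_len]
    # Max-pool along the token axis so isolated spikes keep their local neighborhood alive.
    scores = F.max_pool1d(scores.unsqueeze(1), kernel_size=kernel, stride=1, padding=kernel // 2).squeeze(1)
    # Aggregate per-head scores into a single importance value per cached token.
    return scores.mean(dim=0)  # [kv_len]
```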
Refined modes enhance standard eviction strategies by selecting the most representative tokens or blocks from the evictable (intermediate) region. These methods aim to balance contextual importance with redundancy reduction to optimize cache efficiency. If `refined_algorithm` is enabled but `refined_tokens` is not specified or is set to 0, the number of refined tokens is determined dynamically as part of the intermediate token budget. The budget for the primary algorithm is allocated by selecting the minimal number of tokens or groups that together capture at least 90% of the total attention mass, ensuring that all high-importance tokens are retained. For the remaining eviction budget, each token's dissimilarity is computed relative to the already retained set, promoting information diversity and reducing redundancy.
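A minimal sketch of the attention-mass rule described above, assuming a 1-D tensor of per-token importance scores; the helper name and the 90% threshold argument are illustrative, not part of the module's API.

```python
import torch


def primary_token_budget(importance: torch.Tensor, mass: float = 0.9) -> torch.Tensor:
    """Return indices of the smallest set of tokens whose scores cover `mass` of the total."""
    order = torch.argsort(importance, descending=True)
    cumulative = torch.cumsum(importance[order], dim=0) / importance.sum()
    keep = int(torch.searchsorted(cumulative, torch.tensor(mass)).item()) + 1
    # Tokens kept by the primary algorithm; the remaining budget goes to the refined pass.
    return order[:keep]
```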
Supported refined modes:
- **KVCrush Mode** – Selects representative blocks based on diversity rather than raw importance. Binary indicators are generated for each token, an anchor point (reference pattern) is constructed using one of several modes (`random`, `zeros`, `ones`, `mean`, `alternate`), and the blocks with the highest Hamming distance to the anchor point are selected; see the first sketch after this list.
  Paper: https://arxiv.org/pdf/2503.00022
- **DiverseKV Mode** – Implements a dynamic redundancy scoring mechanism that identifies and de-prioritizes repetitive tokens based on the cosine similarity of their key vectors with already retained tokens. Key vectors are normalized, and cosine similarities are computed with the diagonal zeroed to avoid self-similarity. Similarities are thresholded per head: only values greater than or equal to each head's mean similarity are kept and then aggregated across heads. For the remaining eviction budget, each token or group's dissimilarity to the already retained tokens or groups is computed, and those with the highest dissimilarity are retained, maximizing contextual diversity while reducing redundancy; see the second sketch after this list.
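The two sketches below are rough illustrations of the selection ideas named above, not the module's actual code; shapes, helper names, and defaults are assumptions. The first follows the KVCrush idea: build a binary pattern per block, pick an anchor pattern, and keep the blocks farthest from it in Hamming distance.

```python
import torch


def kvcrush_select(indicators: torch.Tensor, budget: int, anchor_mode: str = "mean") -> torch.Tensor:
    """indicators: [num_blocks, num_heads] binary patterns describing each evictable block."""
    num_heads = indicators.shape[1]
    if anchor_mode == "zeros":
        anchor = torch.zeros(num_heads)
    elif anchor_mode == "ones":
        anchor = torch.ones(num_heads)
    elif anchor_mode == "alternate":
        anchor = (torch.arange(num_heads) % 2).float()
    elif anchor_mode == "mean":
        anchor = (indicators.float().mean(dim=0) >= 0.5).float()
    else:  # "random"
        anchor = torch.randint(0, 2, (num_heads,)).float()
    # Blocks whose patterns differ most from the anchor are the most diverse picks.
    hamming = (indicators.float() != anchor).sum(dim=-1)
    return torch.topk(hamming, k=min(budget, hamming.numel())).indices
```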
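The second sketch follows the DiverseKV redundancy scoring described above: normalize keys, compute cosine similarities with the diagonal zeroed, keep only values at or above each head's mean, and aggregate across heads into a per-token redundancy score (lower means more diverse).

```python
import torch
import torch.nn.functional as F


def diversekv_redundancy(keys: torch.Tensor) -> torch.Tensor:
    """keys: cached key vectors of the evictable region, shaped [num_heads, num_tokens, head_dim]."""
    normalized = F.normalize(keys, dim=-1)
    sim = normalized @ normalized.transpose(-1, -2)   # [num_heads, T, T] cosine similarities
    sim.diagonal(dim1=-2, dim2=-1).zero_()            # drop self-similarity
    per_head_mean = sim.mean(dim=(-1, -2), keepdim=True)
    sim = torch.where(sim >= per_head_mean, sim, torch.zeros_like(sim))
    # High values mean a token is similar to many others, i.e. redundant; low values mean diverse.
    return sim.sum(dim=-1).mean(dim=0)                # [num_tokens]
```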

## Supported and tested models

@@ -53,6 +62,12 @@ Multimodal Large Language Models:

- [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
- [Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)

Large Reasoning Models:

- [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
- [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B)
- [microsoft/Phi-4-mini-reasoning](https://huggingface.co/microsoft/Phi-4-mini-reasoning)

## Prerequisites

Before running algorithms, ensure you have **Python 3.10+** installed and set up your environment.

modules/genai_optimizations/benchmarks/README.md

Lines changed: 31 additions & 0 deletions
@@ -10,6 +10,8 @@ This [example](./longbench.py) demonstrates how to evaluate and optimize LLMs us

Sparse attention speeds up the prefill stage in LLMs by attending only to the most relevant query-key blocks. Static patterns like Tri-Shape and dynamic mechanisms like XAttention reduce memory and computation without significant accuracy loss, enabling efficient handling of long prompts.

KV-Cache Token Eviction accelerates the decoding stage in LLMs by removing less important cached tokens while preserving those essential for contextual understanding, allowing efficient long-sequence inference under constrained memory; a minimal usage sketch follows.
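The sketch below shows one way to enable token eviction programmatically, mirroring how the benchmark scripts in this folder apply the patcher from `utils.py`. The flag values follow the command-line examples, while the model, prompt, and generation settings are placeholders; the exact argument names accepted by `add_token_eviction_args` are an assumption.

```python
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import add_token_eviction_args, get_eviction_patcher

parser = argparse.ArgumentParser()
add_token_eviction_args(parser)
args = parser.parse_args(["--enable_eviction", "--algorithm", "rkv",
                          "--granularity", "per_group", "--intermediate_tokens", "512"])

model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Eviction hooks into attention scores, so the eager attention implementation is used here.
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

inputs = tokenizer("What is 12 * 7? Think step by step.", return_tensors="pt").to(model.device)
with get_eviction_patcher(args)(model):  # low-importance KV-cache entries are evicted while decoding
    output = model.generate(**inputs, do_sample=False, max_new_tokens=256)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```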

### Run Example
@@ -100,3 +102,32 @@ This will automatically:

- Evaluate the model and report the score

</details>

<details>
<summary><b>Large Reasoning Models Optimization Example: MATH500 and GSM8K Benchmarks</b></summary>

This [example](./math500_gsm_bench.py) demonstrates how to evaluate and optimize LRMs using the KV-Cache Token Eviction algorithm. The example leverages the [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) and [GSM8K](https://huggingface.co/datasets/openai/gsm8k) datasets.
MATH500 contains a subset of 500 problems from the [MATH](https://github.com/hendrycks/math) benchmark, originally introduced in OpenAI’s Let’s Verify Step by Step paper. The subset covers six domains: algebra, geometry, intermediate algebra, number theory, precalculus, and probability.
GSM8K (Grade School Math 8K) is a dataset of 8,500 high-quality, linguistically diverse grade-school math word problems. While the problems are conceptually simple, they often require multi-step reasoning, and their high diversity makes them challenging for state-of-the-art language models.

### Run Example

```bash
python math500_gsm_bench.py \
  --dataset MATH500 \
  --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
  --max_tokens 5000 \
  --max_examples 100 \
  --enable_eviction \
  --algorithm rkv \
  --granularity per_group \
  --intermediate_tokens 512
```

This will automatically:

- Download the selected model and dataset
- Apply token eviction during the decoding stage
- Evaluate the model and report the score

</details>
modules/genai_optimizations/benchmarks/math500_gsm_bench.py

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# This logic is largely copied from the
# - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
# - https://github.com/openai/prm800k
# - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
# - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
# - https://github.com/VITA-Group/SEAL/tree/main

import argparse
import json
import os
import random
import re
from collections import Counter
from contextlib import ExitStack

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from utils import add_attention_args, add_token_eviction_args
from utils import get_eviction_patcher, get_sparse_attention_patcher

from reasoning_parser import extract_answer
from reasoning_parser import parallel_math_equal
from reasoning_parser import strip_string

# disable tokenizer parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
OUTPUT_LENGTHS = []


def run_evaluation(res_path, save=False, k=None, output_dir=None):
    with open(res_path) as f:
        lines = f.readlines()
        data = [json.loads(line) for line in lines]

    for example in tqdm(data):
        if "model_generation" not in example:
            example["model_generation"] = example["model_output"]
        if k is not None:
            example["model_generation"] = example["model_generation"][:k]
        gt_cot = example["answer"]
        gt_ans = extract_answer(gt_cot, data_name="omni-math")
        gt_cot = str(gt_cot).strip()
        gt_ans = strip_string(gt_ans, skip_unit=False)
        all_pred = [extract_answer(p, data_name="omni-math") for p in example["model_generation"]]
        all_pred = [strip_string(p, skip_unit=False) for p in all_pred]
        all_eval = parallel_math_equal(all_pred, gt_ans, timeout=5)
        effective_pred = [p for p, o in zip(all_pred, example["model_generation"]) if "boxed" in o]
        if len(effective_pred) == 0:
            effective_pred = all_pred
        counter = Counter(effective_pred)
        pred = counter.most_common(1)[0][0]
        index = all_pred.index(pred)
        eval = all_eval[index]
        example["all_pred"] = all_pred
        example["all_eval"] = all_eval
        example["mv_pred"] = pred
        example["mv_eval"] = eval
        example["mv_index"] = index

    acc = sum([example["mv_eval"] for example in data]) / len(data)
    print(f"Accuracy: {acc:.3f}")

    correct_avg_len = []
    incorrect_avg_len = []

    for i, example in enumerate(data):
        if example["mv_eval"]:
            correct_avg_len.append(OUTPUT_LENGTHS[i])
        else:
            incorrect_avg_len.append(OUTPUT_LENGTHS[i])

    if len(correct_avg_len) != 0:
        print(f"Correct avg len: {sum(correct_avg_len) / len(correct_avg_len):.2f}", end=", ")
    if len(incorrect_avg_len) != 0:
        print(f"Incorrect avg len: {sum(incorrect_avg_len) / len(incorrect_avg_len):.2f}")

    if save:
        out_file = os.path.join(output_dir, "math_eval.jsonl")
        with open(out_file, "w") as f:
            for example in data:
                f.write(json.dumps(example) + "\n")

        metric_file = os.path.join(output_dir, "metrics.json")
        with open(metric_file, "w") as f:
            json.dump({"acc": acc}, f)


def trim_output(output):
    instruction_prefix = "Answer the following question"
    question_prefix = "Question:"
    comment_prefix = "Comment:"  # for some reason, Llama 13B likes to generate these comments indefinitely

    for prefix in [instruction_prefix, question_prefix, comment_prefix]:
        if prefix in output:
            output = output.split(prefix)[0]

    return output


def extract_box(pred_str):
    ans = pred_str.split("boxed")[-1]
    if len(ans) == 0:
        return ""
    elif ans[0] == "{":
        stack = 1
        a = ""
        for c in ans[1:]:
            if c == "{":
                stack += 1
                a += c
            elif c == "}":
                stack -= 1
                if stack == 0:
                    break
                a += c
            else:
                a += c
    else:
        a = ans.split("$")[0].strip()

    return a


def prepare_dataset(dataset, max_samples=None):
    test_data = []
    if dataset == "MATH500":
        data = load_dataset("HuggingFaceH4/MATH-500", split="test")
        for example in data:
            gt = extract_box(example["solution"])
            test_data.append(
                {
                    "question": example["problem"],
                    "answer": example["solution"],
                    "gt": gt,
                }
            )
    elif dataset == "GSM":
        data_path = "gsm.jsonl"

        if not os.path.exists(data_path):
            import requests

            url = "https://raw.githubusercontent.com/VITA-Group/SEAL/main/data/gsm/test.jsonl"
            response = requests.get(url)
            response.raise_for_status()
            with open(data_path, "w", encoding="utf-8") as f:
                f.write(response.text)
            print(f"Downloaded and saved to '{data_path}'.")

        with open(data_path) as fin:
            for line in fin:
                example = json.loads(line)
                answer = example["answer"].split("####")[1].strip()
                answer = re.sub(r"(\d),(\d)", r"\1\2", answer)
                test_data.append(
                    {
                        "question": example["question"],
                        "answer": example["answer"].split("####")[0].strip(),
                        "gt": answer,
                    }
                )

    if max_samples and len(test_data) > max_samples:
        test_data = test_data[:max_samples]

    return test_data


def main(args):
    random.seed(42)

    test_data = prepare_dataset(args.dataset, max_samples=args.max_examples)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # set pad token to eos token if pad token is not set (as is the case for llama models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    prefix = (
        "Answer the following questions. You should think step-by-step and put your final answer within \\boxed{}.\n"
    )
    prompts = []
    for example in test_data:
        prompt = prefix + "Question: " + example["question"].strip() + "\nAnswer: "
        if not args.omit_chat_template:
            if "deepseek" in args.model:
                messages = [{"role": "user", "content": prefix + "Question: " + example["question"].strip()}]
            else:
                messages = [
                    {"role": "system", "content": prefix},
                    {"role": "user", "content": "Question: " + example["question"].strip()},
                ]
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if not args.keep_bos and tokenizer.bos_token is not None and prompt.startswith(tokenizer.bos_token):
                prompt = prompt[len(tokenizer.bos_token) :]
        prompts.append(prompt)

    kwargs = {"temperature": None, "top_p": None, "top_k": None}
    # force attn_implementation="eager" when using token eviction without custom attention
    if args.enable_eviction and not args.use_custom_attention:
        kwargs["attn_implementation"] = "eager"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        device_map="auto",
        token=os.environ.get("HF_TOKEN", None),
        **kwargs,
    )
    model.eval()

    contexts = []
    if args.use_custom_attention:
        sparse_attn = get_sparse_attention_patcher(args)
        contexts.append(sparse_attn)

    if args.enable_eviction:
        token_eviction = get_eviction_patcher(args)
        contexts.append(token_eviction)

    outputs = []
    avg_prompt_len = []
    with ExitStack() as stack:
        for ctx in contexts:
            if ctx is not None:
                stack.enter_context(ctx(model))

        for prompt in tqdm(prompts):
            tokenized_batch = tokenizer(prompt, return_tensors="pt", padding=True)
            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
            avg_prompt_len.append(tokenized_batch["input_ids"].shape[1])

            output = model.generate(
                **tokenized_batch,
                do_sample=False,
                max_new_tokens=args.max_tokens,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            OUTPUT_LENGTHS.append(output.shape[1])
            output = [tokenizer.decode(o[avg_prompt_len[-1]:], skip_special_tokens=True) for o in output]
            outputs.extend(output)

    outputs = [[trim_output(o)] for o in outputs]
    print(f"Average prompt length: {sum(avg_prompt_len) / len(avg_prompt_len):.2f}")
    print(f"Average length: {sum(OUTPUT_LENGTHS) / len(OUTPUT_LENGTHS):.2f}")

    predictions = [
        {
            "prompt": prompt,
            "problem": example["question"],
            "answer": example["gt"],
            "solution": example["answer"],
            "model_generation": output,
        }
        for example, output, prompt in zip(test_data, outputs, prompts)
    ]

    with open(os.path.join(args.save_dir, "predictions.jsonl"), "w") as fout:
        for prediction in predictions:
            fout.write(json.dumps(prediction) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--dataset", type=str, default="MATH500", choices=["MATH500", "GSM"])
    parser.add_argument("--max_examples", type=int, default=None)
    parser.add_argument("--start", type=int, default=None)
    parser.add_argument("--save_dir", type=str, default="results")
    parser.add_argument("--max_tokens", type=int, default=5000)
    parser.add_argument("--omit_chat_template", action="store_true")
    parser.add_argument("--keep_bos", action="store_true")

    add_attention_args(parser)
    add_token_eviction_args(parser)
    args = parser.parse_args()

    args.save_dir = os.path.join(args.save_dir, args.dataset)
    if args.keep_bos:
        args.save_dir = args.save_dir + "_keep_bos"

    if args.max_examples or args.start:
        start = 0 if args.start is None else args.start
        end = start + args.max_examples if args.max_examples is not None else -1
        args.save_dir = os.path.join(args.save_dir, f"{start}_{end}")

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print(f"Results will be saved to {args.save_dir}")
    main(args)
    run_evaluation(os.path.join(args.save_dir, "predictions.jsonl"), output_dir=args.save_dir)
