
Commit 6d10335

[GenAI] Support refined algorithms (KVCrush, DiverseKV)
1 parent 61d3193 commit 6d10335

File tree

7 files changed, +1524 -33 lines changed


modules/genai_optimizations/benchmarks/README.md

Lines changed: 29 additions & 0 deletions
@@ -10,6 +10,8 @@ This [example](./longbench.py) demonstrates how to evaluate and optimize LLMs us

Sparse attention speeds up the prefill stage in LLMs by attending only to the most relevant query-key blocks. Static patterns like Tri-Shape and dynamic mechanisms like XAttention reduce memory and computation without significant accuracy loss, enabling efficient handling of long prompts.
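A rough sketch of the block-selection idea is shown below (illustrative only; the function name, block size, and mean-pooling heuristic are assumptions, not the module's actual Tri-Shape/XAttention implementation):

```python
import torch

def select_key_blocks(q, k, block_size=64, keep=8):
    # q, k: [num_heads, seq_len, head_dim]; seq_len assumed divisible by block_size
    h, s, d = q.shape
    nb = s // block_size
    # summarize each block of queries/keys by mean pooling
    qb = q.reshape(h, nb, block_size, d).mean(dim=2)
    kb = k.reshape(h, nb, block_size, d).mean(dim=2)
    scores = qb @ kb.transpose(-1, -2) / d**0.5                     # [h, nb, nb]
    causal = ~torch.triu(torch.ones(nb, nb, dtype=torch.bool), 1)   # no future blocks
    scores = scores.masked_fill(~causal, float("-inf"))
    top = scores.topk(min(keep, nb), dim=-1).indices                # best key blocks per query block
    mask = torch.zeros(h, nb, nb, dtype=torch.bool)
    mask.scatter_(-1, top, True)
    mask[..., 0] = True          # always keep the first ("sink") block
    return mask & causal         # True = this query block attends to this key block
```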

KV-Cache Token Eviction accelerates the decoding stage in LLMs by removing less important cached tokens while preserving those essential for contextual understanding, allowing efficient long-sequence inference under constrained memory.
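As a simplified illustration of the eviction idea (a minimal sketch with assumed names and a plain accumulated-attention heuristic, not the KVCrush or DiverseKV algorithms introduced by this commit):

```python
import torch

def evict_kv(keys, values, attn_weights, max_cache_size=1024, recent=128):
    # keys/values: [num_heads, seq_len, head_dim]; attn_weights: [num_heads, q_len, seq_len]
    seq_len = keys.shape[1]
    if seq_len <= max_cache_size:
        return keys, values
    # importance of each cached token = total attention it has received
    scores = attn_weights.sum(dim=(0, 1))
    scores[-recent:] = float("inf")  # the most recent tokens are always kept
    keep = scores.topk(max_cache_size).indices.sort().values
    return keys[:, keep], values[:, keep]
```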
### Run Example

```bash
@@ -100,3 +102,30 @@ This will automatically:
- Evaluate the model and report the score

</details>
<details>
<summary><b>Large Reasoning Models Optimization Example: MATH500 and GSM8K Benchmarks</b></summary>
This [example](./math500_gsm_bench.py) demonstrates how to evaluate and optimize LRMs using the KV-Cache Token Eviction algorithm. The example leverages [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) and [GSM8K](https://huggingface.co/datasets/openai/gsm8k) datasets.
MATH500 contains a subset of 500 problems from the [MATH](https://github.com/hendrycks/math) benchmark, originally introduced in OpenAI’s Let’s Verify Step by Step paper. The subset covers six domains: algebra, geometry, intermediate algebra, number theory, precalculus, and probability.
GSM8K (Grade School Math 8K) is a dataset of 8,500 high-quality, linguistically diverse grade-school math word problems. While the problems are conceptually simple, they typically require multi-step reasoning, which makes them challenging even for state-of-the-art language models.
### Run Example
```bash
python math500_gsm_bench.py \
--dataset GSM \
--model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--enable_eviction \
--algorithm rkv \
--granularity per_group \
--intermediate_tokens 1024
```
This will automatically:
- Download the selected model and dataset
- Apply token eviction during the decoding stage
- Evaluate the model and report the score
</details>
modules/genai_optimizations/benchmarks/math500_gsm_bench.py

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# This logic is largely copied from the
# - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
# - https://github.com/openai/prm800k
# - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
# - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
# - https://github.com/VITA-Group/SEAL/tree/main

import argparse
import json
import os
import random
import re
from collections import Counter
from contextlib import ExitStack

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from utils import add_attention_args, add_token_eviction_args
from utils import get_eviction_patcher, get_sparse_attention_patcher

from reasoning_parser import extract_answer
from reasoning_parser import parallel_math_equal
from reasoning_parser import strip_string

# disable tokenizer parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
OUTPUT_LENGTHS = []
def run_evaluation(res_path, save=False, k=None, output_dir=None):
    with open(res_path) as f:
        lines = f.readlines()
    data = [json.loads(line) for line in lines]

    for example in tqdm(data):
        if "model_generation" not in example:
            example["model_generation"] = example["model_output"]
        if k is not None:
            example["model_generation"] = example["model_generation"][:k]
        gt_cot = example["answer"]
        gt_ans = extract_answer(gt_cot, data_name="omni-math")
        gt_cot = str(gt_cot).strip()
        gt_ans = strip_string(gt_ans, skip_unit=False)
        all_pred = [extract_answer(p, data_name="omni-math") for p in example["model_generation"]]
        all_pred = [strip_string(p, skip_unit=False) for p in all_pred]
        all_eval = parallel_math_equal(all_pred, gt_ans, timeout=5)
        effective_pred = [p for p, o in zip(all_pred, example["model_generation"]) if "boxed" in o]
        if len(effective_pred) == 0:
            effective_pred = all_pred
        counter = Counter(effective_pred)
        pred = counter.most_common(1)[0][0]
        index = all_pred.index(pred)
        eval = all_eval[index]
        example["all_pred"] = all_pred
        example["all_eval"] = all_eval
        example["mv_pred"] = pred
        example["mv_eval"] = eval
        example["mv_index"] = index

    acc = sum([example["mv_eval"] for example in data]) / len(data)
    print(f"Accuracy: {acc:.3f}")

    correct_avg_len = []
    incorrect_avg_len = []

    for i, example in enumerate(data):
        if example["mv_eval"]:
            correct_avg_len.append(OUTPUT_LENGTHS[i])
        else:
            incorrect_avg_len.append(OUTPUT_LENGTHS[i])

    if len(correct_avg_len) != 0:
        print(f"Correct avg len: {sum(correct_avg_len) / len(correct_avg_len):.2f}", end=", ")
    if len(incorrect_avg_len) != 0:
        print(f"Incorrect avg len: {sum(incorrect_avg_len) / len(incorrect_avg_len):.2f}")

    if save:
        out_file = os.path.join(output_dir, "math_eval.jsonl")
        with open(out_file, "w") as f:
            for example in data:
                f.write(json.dumps(example) + "\n")

        metric_file = os.path.join(output_dir, "metrics.json")
        with open(metric_file, "w") as f:
            json.dump({"acc": acc}, f)
def trim_output(output):
    instruction_prefix = "Answer the following question"
    question_prefix = "Question:"
    comment_prefix = "Comment:"  # for some reason, Llama 13B likes to generate these comments indefinitely

    for prefix in [instruction_prefix, question_prefix, comment_prefix]:
        if prefix in output:
            output = output.split(prefix)[0]

    return output


def extract_box(pred_str):
    ans = pred_str.split("boxed")[-1]
    if len(ans) == 0:
        return ""
    elif ans[0] == "{":
        stack = 1
        a = ""
        for c in ans[1:]:
            if c == "{":
                stack += 1
                a += c
            elif c == "}":
                stack -= 1
                if stack == 0:
                    break
                a += c
            else:
                a += c
    else:
        a = ans.split("$")[0].strip()

    return a
def prepare_dataset(dataset, max_samples=None):
    test_data = []
    if dataset == "MATH500":
        data = load_dataset("HuggingFaceH4/MATH-500", split="test")
        for example in data:
            gt = extract_box(example["solution"])
            test_data.append(
                {
                    "question": example["problem"],
                    "answer": example["solution"],
                    "gt": gt,
                }
            )
    elif dataset == "GSM":
        data_path = "data/gsm/test.jsonl"
        with open(data_path) as fin:
            for line in fin:
                example = json.loads(line)
                answer = example["answer"].split("####")[1].strip()
                answer = re.sub(r"(\d),(\d)", r"\1\2", answer)
                test_data.append(
                    {
                        "question": example["question"],
                        "answer": example["answer"].split("####")[0].strip(),
                        "gt": answer,
                    }
                )

    if max_samples and len(test_data) > max_samples:
        test_data = test_data[:max_samples]

    return test_data
def main(args):
    random.seed(42)

    test_data = prepare_dataset(args.dataset, max_samples=args.max_examples)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # set pad token to eos token if pad token is not set (as is the case for llama models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    contexts = []
    if args.use_custom_attention:
        sparse_attn = get_sparse_attention_patcher(args)
        contexts.append(sparse_attn)

    if args.enable_eviction:
        token_eviction = get_eviction_patcher(args)
        contexts.append(token_eviction)

    prefix = (
        "Answer the following questions. You should think step-by-step and put your final answer within \\boxed{}.\n"
    )
    prompts = []
    for example in test_data:
        prompt = prefix + "Question: " + example["question"].strip() + "\nAnswer: "
        if args.use_chat_format:
            if "deepseek" in args.model:
                messages = [{"role": "user", "content": prefix + "Question: " + example["question"].strip()}]
            else:
                messages = [
                    {"role": "system", "content": prefix},
                    {"role": "user", "content": "Question: " + example["question"].strip()},
                ]
            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        if args.remove_bos and tokenizer.bos_token is not None and prompt.startswith(tokenizer.bos_token):
            prompt = prompt[len(tokenizer.bos_token) :]
        prompts.append(prompt)

    kwargs = {"temperature": None, "top_p": None, "top_k": None}
    # force attn_implementation="eager" when using token eviction without custom attention
    if args.enable_eviction and not args.use_custom_attention:
        kwargs["attn_implementation"] = "eager"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        device_map="auto",
        token=os.environ.get("HF_TOKEN", None),
        **kwargs
    )
    model.eval()

    contexts = []
    if args.use_custom_attention:
        sparse_attn = get_sparse_attention_patcher(args)
        contexts.append(sparse_attn)

    if args.enable_eviction:
        token_eviction = get_eviction_patcher(args)
        contexts.append(token_eviction)

    outputs = []
    prompts_with_eviction = 0
    avg_prompt_len = []
    with ExitStack() as stack:
        for ctx in contexts:
            if ctx is not None:
                stack.enter_context(ctx(model))

        for prompt in prompts:
            tokenized_batch = tokenizer(prompt, return_tensors="pt", padding=True)
            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
            avg_prompt_len.append(tokenized_batch["input_ids"].shape[1])

            output = model.generate(
                **tokenized_batch,
                do_sample=False,
                max_new_tokens=args.max_tokens,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            OUTPUT_LENGTHS.append(output.shape[1])
            # only meaningful when eviction is enabled; guard to avoid a NameError otherwise
            if args.enable_eviction and output.shape[1] > token_eviction.max_cache_size:
                prompts_with_eviction += 1
            output = [tokenizer.decode(o[avg_prompt_len[-1]:], skip_special_tokens=True) for o in output]
            outputs.extend(output)

    outputs = [[trim_output(o)] for o in outputs]
    print(f"Average prompt length: {sum(avg_prompt_len) / len(avg_prompt_len):.2f}")
    print(f"Average length: {sum(OUTPUT_LENGTHS) / len(OUTPUT_LENGTHS):.2f}")
    print(f"Prompts with eviction: {prompts_with_eviction}/{len(OUTPUT_LENGTHS)}")

    predictions = [
        {
            "prompt": prompt,
            "problem": example["question"],
            "answer": example["gt"],
            "solution": example["answer"],
            "model_generation": output,
        }
        for example, output, prompt in zip(test_data, outputs, prompts)
    ]

    with open(os.path.join(args.save_dir, "predictions.jsonl"), "w") as fout:
        for prediction in predictions:
            fout.write(json.dumps(prediction) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--dataset", type=str, default="MATH500", choices=["MATH500", "GSM"])
    parser.add_argument("--max_examples", type=int, default=None)
    parser.add_argument("--start", type=int, default=None)
    parser.add_argument("--save_dir", type=str, default="results")
    parser.add_argument("--use_chat_format", action="store_true")
    parser.add_argument("--max_tokens", type=int, default=512)
    parser.add_argument("--remove_bos", action="store_true", default=True)

    add_attention_args(parser)
    add_token_eviction_args(parser)
    args = parser.parse_args()

    args.save_dir = os.path.join(args.save_dir, args.dataset)
    if args.remove_bos:
        args.save_dir = args.save_dir + "_remove_bos"

    if args.max_examples or args.start:
        start = 0 if args.start is None else args.start
        end = start + args.max_examples if args.max_examples is not None else -1
        args.save_dir = os.path.join(args.save_dir, f"{start}_{end}")

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print(f"Results will be saved to {args.save_dir}")
    main(args)
    run_evaluation(os.path.join(args.save_dir, "predictions.jsonl"), output_dir=args.save_dir)
