
Commit 0d51c55

[GenAI] Support Token Eviction for LRMs
1 parent 6d10335 commit 0d51c55

6 files changed: +163 additions, -65 deletions


modules/genai_optimizations/README.md

Lines changed: 15 additions & 0 deletions
@@ -6,6 +6,7 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
 
 - Text Generation Using LLMs
 - Visual language text generation
+- Reasoning and Problem Solving
 
 ## Supported Generative AI Optimization Methods
 
@@ -34,6 +35,14 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
   Paper: https://arxiv.org/pdf/2306.14048
 - **SnapKV Mode** – Modifies the *H2O* approach by computing token importance within a small sliding window of the most recent queries during the prefill stage, then reverting to the H2O strategy during decoding. The authors observed that only a small subset of prompt tokens is sufficient for accurate response generation.
   Paper: https://arxiv.org/pdf/2404.14469
+- **RKV Mode** – Computes token importance scores from attention weights over a sliding window of the most recent queries during both the prefill and decode stages. Importance scores are stabilized with per-token max-pooling and then averaged across attention heads.
+
+Refined modes enhance the standard eviction strategies by selecting the most representative tokens or blocks from the evictable (intermediate) region. These methods aim to balance contextual importance with redundancy reduction to optimize cache efficiency. If `refined_algorithm` is enabled but `refined_tokens` is unspecified or set to 0, the number of refined tokens is determined dynamically as part of the intermediate token budget. The budget for the primary algorithm is allocated by selecting the minimal number of tokens or groups that together capture at least 90% of the total attention mass, ensuring that all high-importance tokens are retained. For the remaining eviction budget, each token's dissimilarity to the already retained set is computed, promoting information diversity and reducing redundancy.
+
+Supported refined modes:
+- **KVCrush Mode** – Selects representative blocks based on diversity rather than raw importance. Each token is mapped to a binary indicator, an anchor point (reference pattern) is constructed using one of several modes (`random`, `zeros`, `ones`, `mean`, `alternate`), and the blocks with the highest Hamming distance to the anchor point are selected.
+  Paper: https://arxiv.org/pdf/2503.00022
+- **DiverseKV Mode** – Implements a dynamic redundancy scoring mechanism that identifies and de-prioritizes repetitive tokens based on the cosine similarity of their key vectors with already retained tokens. Key vectors are normalized, and cosine similarities are computed with diagonal values zeroed to avoid self-similarity. Similarities are thresholded on a per-head basis: only values greater than or equal to each head's mean similarity are kept, then aggregated across heads. For the remaining eviction budget, each token's or group's dissimilarity to the already retained tokens or groups is calculated, and the tokens or groups with the highest dissimilarity scores are retained, maximizing contextual diversity while reducing redundancy.
 
 ## Supported and tested models
 
@@ -53,6 +62,12 @@ Multimodal Large Language Models:
 - [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
 - [Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
 
+Large Reasoning Models:
+
+- [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
+- [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B)
+- [microsoft/Phi-4-mini-reasoning](https://huggingface.co/microsoft/Phi-4-mini-reasoning)
+
 ## Prerequisites
 
 Before running algorithms, ensure you have **Python 3.10+** installed and set up your environment.
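The README text in this diff describes the new scoring mechanisms only in prose; the next few sketches illustrate them in PyTorch. First, RKV scoring: a minimal, hypothetical sketch (not the module's actual API) in which importance is accumulated from attention weights over a sliding window of recent queries, smoothed with per-token max-pooling, and averaged across heads. The function name, shapes, and pooling kernel size are assumptions.

```python
import torch
import torch.nn.functional as F


def rkv_token_scores(attn_weights: torch.Tensor, window: int = 32, pool_kernel: int = 7) -> torch.Tensor:
    """Hypothetical RKV-style importance scores.

    attn_weights: [num_heads, q_len, kv_len] softmax-normalized attention weights.
    Returns one importance score per cached token, shape [kv_len].
    """
    # Only the most recent `window` queries vote on token importance.
    recent = attn_weights[:, -window:, :]                      # [H, W, kv_len]
    # Attention mass each key/value token receives from the window.
    scores = recent.sum(dim=1)                                 # [H, kv_len]
    # Stabilize with per-token max-pooling over a small neighborhood.
    scores = F.max_pool1d(scores.unsqueeze(1), kernel_size=pool_kernel,
                          stride=1, padding=pool_kernel // 2).squeeze(1)
    # Average across attention heads to get a single score per token.
    return scores.mean(dim=0)                                  # [kv_len]
```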
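The dynamic budget split described for the refined modes can be sketched as follows: keep the smallest set of top-scoring tokens that covers roughly 90% of the total attention mass, and hand the rest of the intermediate budget to the diversity-based selection. This is an illustrative reading of the README text, with assumed names.

```python
import torch


def split_budget_by_mass(scores: torch.Tensor, budget: int, mass: float = 0.9) -> tuple[int, int]:
    """Hypothetical split of the intermediate-token budget.

    scores: non-negative importance scores, shape [num_tokens].
    Returns (primary_tokens, refined_tokens) summing to `budget`.
    """
    sorted_scores, _ = scores.sort(descending=True)
    cumulative = sorted_scores.cumsum(dim=0) / scores.sum()
    # Smallest prefix of top-scoring tokens that captures `mass` of the total attention mass.
    n_primary = int(torch.searchsorted(cumulative, torch.tensor(mass)).item()) + 1
    n_primary = min(n_primary, budget)
    return n_primary, budget - n_primary
```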
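For the KVCrush refinement, a hedged sketch of the anchor-and-Hamming-distance idea follows; the indicator construction and helper names are assumptions, not the commit's implementation.

```python
import torch


def kvcrush_select(indicators: torch.Tensor, k: int, anchor_mode: str = "mean") -> torch.Tensor:
    """Hypothetical KVCrush-style selection.

    indicators: [num_blocks, dim] binary (0/1) pattern per evictable block.
    Returns indices of the k blocks farthest (in Hamming distance) from the anchor,
    i.e. the most diverse blocks rather than the highest-scoring ones.
    """
    dim = indicators.shape[1]
    if anchor_mode == "zeros":
        anchor = torch.zeros(dim)
    elif anchor_mode == "ones":
        anchor = torch.ones(dim)
    elif anchor_mode == "alternate":
        anchor = (torch.arange(dim) % 2).float()
    elif anchor_mode == "random":
        anchor = torch.randint(0, 2, (dim,)).float()
    else:  # "mean": round the average indicator pattern
        anchor = indicators.float().mean(dim=0).round()
    # Hamming distance = number of positions where a block's pattern differs from the anchor.
    hamming = (indicators.float() != anchor).sum(dim=-1)
    return torch.topk(hamming, k=k).indices
```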
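And for DiverseKV, a rough sketch of the redundancy/dissimilarity scoring, again with assumed shapes and names:

```python
import torch
import torch.nn.functional as F


def diversekv_dissimilarity(keys: torch.Tensor, retained_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical DiverseKV-style scoring.

    keys: [num_heads, seq_len, head_dim] key vectors; retained_idx: indices of tokens
    already kept. Returns a per-token dissimilarity score (higher = more diverse,
    i.e. more worth keeping next).
    """
    # Normalized keys turn dot products into cosine similarities.
    keys = F.normalize(keys, dim=-1)
    retained = keys[:, retained_idx, :]                        # [H, R, D]
    sim = torch.einsum("hsd,hrd->hsr", keys, retained)         # [H, S, R]
    # Per-head thresholding: keep only similarities at or above the head's mean.
    head_mean = sim.mean(dim=(1, 2), keepdim=True)
    sim = torch.where(sim >= head_mean, sim, torch.zeros_like(sim))
    # Redundancy = similarity mass to the retained set, aggregated across heads;
    # dissimilarity is its negation, so the least redundant tokens rank highest.
    redundancy = sim.sum(dim=-1).mean(dim=0)                   # [S]
    return -redundancy
```

Note that `keys` here still contains the retained tokens themselves, so a fuller implementation would mask out self-similarities, which is the diagonal zeroing the README mentions.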

modules/genai_optimizations/benchmarks/README.md

Lines changed: 4 additions & 2 deletions
@@ -115,12 +115,14 @@ GSM8K (Grade School Math 8K) is a dataset of 8,500 high-quality, linguistically
 
 ```bash
 python math500_gsm_bench.py \
-    --subset gsm \
+    --dataset MATH500 \
     --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --max_tokens 5000 \
+    --max_examples 100 \
     --enable_eviction \
     --algorithm rkv \
     --granularity per_group \
-    --intermediate_tokens 1024
+    --intermediate_tokens 512
 ```
 This will automatically:
 
modules/genai_optimizations/benchmarks/math500_gsm_bench.py

Lines changed: 33 additions & 27 deletions
@@ -141,7 +141,17 @@ def prepare_dataset(dataset, max_samples=None):
             }
         )
     elif dataset == "GSM":
-        data_path = "data/gsm/test.jsonl"
+        data_path = "gsm.jsonl"
+
+        if not os.path.exists(data_path):
+            import requests
+            url = "https://raw.githubusercontent.com/VITA-Group/SEAL/main/data/gsm/test.jsonl"
+            response = requests.get(url)
+            response.raise_for_status()
+            with open(data_path, "w", encoding="utf-8") as f:
+                f.write(response.text)
+            print(f"Downloaded and saved to '{data_path}'.")
+
         with open(data_path) as fin:
             for line in fin:
                 example = json.loads(line)
@@ -187,7 +197,7 @@ def main(args):
     prompts = []
     for example in test_data:
         prompt = prefix + "Question: " + example["question"].strip() + "\nAnswer: "
-        if args.use_chat_format:
+        if not args.omit_chat_template:
             if "deepseek" in args.model:
                 messages = [{"role": "user", "content": prefix + "Question: " + example["question"].strip()}]
             else:
@@ -196,7 +206,7 @@ def main(args):
                     {"role": "user", "content": "Question: " + example["question"].strip()},
                 ]
             prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        if args.remove_bos and tokenizer.bos_token is not None and prompt.startswith(tokenizer.bos_token):
+        if not args.keep_bos and tokenizer.bos_token is not None and prompt.startswith(tokenizer.bos_token):
             prompt = prompt[len(tokenizer.bos_token) :]
         prompts.append(prompt)
 
@@ -224,35 +234,31 @@ def main(args):
         contexts.append(token_eviction)
 
     outputs = []
-    prompts_with_eviction = 0
     avg_prompt_len = []
     with ExitStack() as stack:
         for ctx in contexts:
             if ctx is not None:
                 stack.enter_context(ctx(model))
 
-        for prompt in prompts:
-            tokenized_batch = tokenizer(prompt, return_tensors="pt", padding=True)
-            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
-            avg_prompt_len.append(tokenized_batch["input_ids"].shape[1])
-
-            output = model.generate(
-                **tokenized_batch,
-                do_sample=False,
-                max_new_tokens=args.max_tokens,
-                use_cache=True,
-                pad_token_id=tokenizer.eos_token_id,
-            )
-            OUTPUT_LENGTHS.append(output.shape[1])
-            if output.shape[1] > token_eviction.max_cache_size:
-                prompts_with_eviction += 1
-            output = [tokenizer.decode(o[avg_prompt_len[-1]:], skip_special_tokens=True) for o in output]
-            outputs.extend(output)
+        for prompt in tqdm(prompts):
+            tokenized_batch = tokenizer(prompt, return_tensors="pt", padding=True)
+            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
+            avg_prompt_len.append(tokenized_batch["input_ids"].shape[1])
+
+            output = model.generate(
+                **tokenized_batch,
+                do_sample=False,
+                max_new_tokens=args.max_tokens,
+                use_cache=True,
+                pad_token_id=tokenizer.eos_token_id,
+            )
+            OUTPUT_LENGTHS.append(output.shape[1])
+            output = [tokenizer.decode(o[avg_prompt_len[-1]:], skip_special_tokens=True) for o in output]
+            outputs.extend(output)
 
     outputs = [[trim_output(o)] for o in outputs]
     print(f"Average prompt length: {sum(avg_prompt_len) / len(avg_prompt_len):.2f}")
     print(f"Average length: {sum(OUTPUT_LENGTHS) / len(OUTPUT_LENGTHS):.2f}")
-    print(f"Prompts with eviction: {prompts_with_eviction}/{len(OUTPUT_LENGTHS)}")
 
     predictions = [
         {
@@ -277,17 +283,17 @@ def main(args):
     parser.add_argument("--max_examples", type=int, default=None)
     parser.add_argument("--start", type=int, default=None)
    parser.add_argument("--save_dir", type=str, default="results")
-    parser.add_argument("--use_chat_format", action="store_true")
-    parser.add_argument("--max_tokens", type=int, default=512)
-    parser.add_argument("--remove_bos", action="store_true", default=True)
+    parser.add_argument("--max_tokens", type=int, default=5000)
+    parser.add_argument("--omit_chat_template", action="store_true")
+    parser.add_argument("--keep_bos", action="store_true")
 
     add_attention_args(parser)
     add_token_eviction_args(parser)
     args = parser.parse_args()
 
     args.save_dir = os.path.join(args.save_dir, args.dataset)
-    if args.remove_bos:
-        args.save_dir = args.save_dir + "_remove_bos"
+    if args.keep_bos:
+        args.save_dir = args.save_dir + "_keep_bos"
 
     if args.max_examples or args.start:
         start = 0 if args.start is None else args.start

modules/genai_optimizations/benchmarks/reasoning_parser.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def strip_string(string, skip_unit=False):
     string = string.replace("infinity", "\\infty")
     if "\\infty" not in string:
         string = string.replace("inf", "\\infty")
-    string = string.replace("+\\inity", "\\infty")
+    string = string.replace("\\inity", "\\infty")
 
     # and
     string = string.replace("and", "")

modules/genai_optimizations/genai_opt/sparse_attention.py

Lines changed: 89 additions & 2 deletions
@@ -16,6 +16,7 @@
 from transformers.cache_utils import Cache
 from transformers.models.llama.modeling_llama import repeat_kv
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb as phi3_apply_rotary_pos_emb
 from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
 
 from block_sparse_attn import block_sparse_attn_func
@@ -619,7 +620,7 @@ def qwen2_vl_forward(
         value_states=value_states,
         attention_mask=attention_mask,
         scaling=module.scaling,
-        dropout_p=module.attention_dropout if module.training else 0.0,
+        dropout=module.attention_dropout if module.training else 0.0,
     )
 
     attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -657,7 +658,91 @@ def llama_forward(
         key_states=key_states,
         value_states=value_states,
         attention_mask=attention_mask,
-        dropout_p=module.attention_dropout if module.training else 0.0,
+        dropout=module.attention_dropout if module.training else 0.0,
+        scaling=module.scaling,
+    )
+
+    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+    attn_output = module.o_proj(attn_output)
+    return attn_output, attn_weights
+
+
+def qwen3_forward(
+    module,
+    hidden_states: torch.Tensor,
+    position_embeddings: tuple[torch.Tensor, torch.Tensor],
+    attention_mask: Optional[torch.Tensor],
+    past_key_values: Optional[Cache] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    input_shape = hidden_states.shape[:-1]
+    hidden_shape = (*input_shape, -1, module.head_dim)
+
+    query_states = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    key_states = module.k_norm(module.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    value_states = module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+    cos, sin = position_embeddings
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_values is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx, cache_kwargs)
+
+    attn_output, attn_weights = module.attn_interface(
+        module,
+        query_states=query_states,
+        key_states=key_states,
+        value_states=value_states,
+        attention_mask=attention_mask,
+        dropout=module.attention_dropout if module.training else 0.0,
+        scaling=module.scaling,
+    )
+
+    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+    attn_output = module.o_proj(attn_output)
+    return attn_output, attn_weights
+
+
+def phi_forward(
+    module,
+    hidden_states: torch.Tensor,
+    position_embeddings: tuple[torch.Tensor, torch.Tensor],
+    attention_mask: Optional[torch.Tensor],
+    past_key_values: Optional[Cache] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    input_shape = hidden_states.shape[:-1]
+    hidden_shape = (*input_shape, -1, module.head_dim)
+
+    qkv = module.qkv_proj(hidden_states)
+    query_pos = module.config.num_attention_heads * module.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + module.num_key_value_heads * module.head_dim]
+    value_states = qkv[..., query_pos + module.num_key_value_heads * module.head_dim :]
+
+    query_states = query_states.view(hidden_shape).transpose(1, 2)
+    key_states = key_states.view(hidden_shape).transpose(1, 2)
+    value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+    cos, sin = position_embeddings
+    query_states, key_states = phi3_apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_values is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx, cache_kwargs)
+
+    attn_output, attn_weights = module.attn_interface(
+        module,
+        query_states=query_states,
+        key_states=key_states,
+        value_states=value_states,
+        attention_mask=attention_mask,
+        dropout=module.attention_dropout if module.training else 0.0,
         scaling=module.scaling,
     )
 
@@ -672,6 +757,8 @@ def llama_forward(
     "LlamaForCausalLM": llama_forward,
     "MistralForCausalLM": llama_forward,
     "Qwen2ForCausalLM": llama_forward,
+    "Qwen3ForCausalLM": qwen3_forward,
+    "Phi3ForCausalLM": phi_forward,
 }
 
 def get_custom_attn_forward(model: PreTrainedModel):
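For context on the registry change in this last hunk: Qwen3 and Phi-3-style architectures now dispatch to their own custom forwards. The snippet below is a hypothetical illustration of how such a registry entry might be bound onto a model's attention modules; the patching helper, the `attn_interface` wiring, and the layer layout are assumptions, not code from this commit.

```python
import types

from transformers import AutoModelForCausalLM


def patch_attention_forwards(model, custom_forward, attn_interface):
    """Bind a custom attention forward onto every decoder layer (illustrative only).

    Assumes the usual HF layout `model.model.layers[i].self_attn` and that the
    custom forwards above expect an `attn_interface` attribute on the module.
    """
    for layer in model.model.layers:
        attn = layer.self_attn
        attn.attn_interface = attn_interface          # eviction/sparse attention kernel
        attn.forward = types.MethodType(custom_forward, attn)
    return model


# Hypothetical usage: "Qwen3ForCausalLM" resolves to qwen3_forward via the registry,
# e.g. through get_custom_attn_forward(model).
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")
# model = patch_attention_forwards(model, qwen3_forward, my_attention_fn)
```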
