
Commit 59d00f4

[feat] cherry-pick to 0.2.0-release to add rerope (#614)
1 parent 0101665 commit 59d00f4

File tree: 13 files changed, +3811 -5 lines changed

Binary image file (72.6 KB): content not shown

docs/source/index.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ user-guide/prefix-cache/index
 user-guide/sparse-attention/index
 user-guide/pd-disaggregation/index
 user-guide/metrics/metrics
+user-guide/rerope/rerope
 :::

 :::{toctree}
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# Rectified Rotary Position Embeddings

Using Rectified Rotary Position Embeddings (ReRoPE), we can extend the context length of an LLM more effectively, without any fine-tuning. This page describes the Triton implementation of ReRoPE and its integration into the vLLM inference framework.
<div align="center">

**🚀 ReRoPE | 📄 blog [https://kexue.fm/archives/9708] [https://normxu.github.io/Rethinking-Rotary-Position-Embedding-3]**

[![License](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/ModelEngine-Group/unified-cache-management/blob/main/LICENSE)
[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)

</div>
## 🌟 What is ReRoPE?

<div align="center">

<img src="https://raw.githubusercontent.com/bojone/rerope/main/idea.png" width=750>

</div>
This approach combines direct extrapolation with position interpolation. A window size $w$ is established: a position interval of $1$ is used within the window, and an interval of $\frac{1}{k}$ is applied outside it. As $k \to \infty$, this simplifies to the form illustrated above. Under this scheme, the position encoding range never exceeds $w$ regardless of input length, potentially enabling support for arbitrarily long contexts.

The attention scores are computed as follows:

$$
\begin{aligned}
score_{ij}^{1} &= (q_iR_i)(k_jR_j)^T, && i-j<w \\
score_{ij}^{2} &= (q_iR_w)(k_j)^T, && i-j\ge w
\end{aligned}
$$

ReRoPE extends context length effectively, but it requires computing attention twice (a local pass within $w$ and a compressed global pass), which significantly reduces throughput. Despite this overhead, it remains valuable for training-free long-context inference, especially when combined with local attention windows to balance efficiency.
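To make the position scheme concrete, here is a minimal PyTorch sketch (not part of the UCM code; the function name and arguments are illustrative) of the effective relative position ReRoPE assigns to each query/key pair:

```python
import torch

def rerope_relative_positions(seq_len: int, w: int, k: float = float("inf")) -> torch.Tensor:
    # i - j for every (query i, key j) pair; entries above the causal diagonal are negative
    rel = (torch.arange(seq_len)[:, None] - torch.arange(seq_len)[None, :]).float()
    # interval 1 inside the window, interval 1/k outside; as k -> inf this clamps at w
    return torch.where(rel < w, rel, w + (rel - w) / k)

# With w=4 and the default k -> inf, every distance >= 4 is encoded as 4, so the
# encoded position range never exceeds w regardless of sequence length.
print(rerope_relative_positions(8, w=4))
```

The score formulas above correspond to the $k \to \infty$ limit, where every out-of-window distance is encoded at the fixed position $w$.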
## 🧠 Triton ReRoPE Implementation

- Load Data

Compared to the Triton RoPE implementation, data loading additionally requires passing query2, which is rotated with the alternative rotary position (the fixed window position $w$), and the unrotated key2.

- Construct ReRoPE Mask

During attention computation, the choice between the two attention-score paths depends on the relative distance between the query and the key, so a ReRoPE mask must be constructed; a sketch follows this list.
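As a rough PyTorch sketch of what the kernel computes (illustrative only; the tensor names `q_rot`, `q_rot_w`, `k_rot`, and `k_raw` are assumptions, not the actual kernel arguments), the mask selects between the two score paths from the formulas above:

```python
import torch

def rerope_scores(q_rot, q_rot_w, k_rot, k_raw, w: int):
    # q_rot:   queries rotated with their own positions i   -> path 1, used when i - j <  w
    # q_rot_w: queries rotated with the fixed position w    -> path 2, used when i - j >= w
    # k_rot:   keys rotated with their positions j; k_raw: unrotated keys
    score1 = q_rot @ k_rot.T    # (q_i R_i)(k_j R_j)^T
    score2 = q_rot_w @ k_raw.T  # (q_i R_w) k_j^T
    rel = torch.arange(q_rot.shape[0])[:, None] - torch.arange(k_rot.shape[0])[None, :]
    rerope_mask = rel < w       # True -> within the window, take path 1
    return torch.where(rerope_mask, score1, score2)
```

In the Triton kernel the same selection happens block-wise inside the attention loop, which is why query2 and the unrotated key2 must be loaded alongside the regular rotated tensors.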
## 🏆 Results

<div align="center">

### The Experiment Results

![ReRoPE Results](../../_static/images/rerope_performace.png)

The experiment is based on a hybrid Transformer-GAU (Gated Attention Unit) model with 100M parameters. $\log n$ indicates that the scale factor $\log n$ is added at the pretraining stage; $\log n^{*}$ denotes applying the scale factor to the attention matrix only for text exceeding the max sequence length, without any pretraining; $w256$ denotes the ReRoPE window $w=256$.

</div>
## 🚀 Quick Start

### Installation

For installation instructions, please refer to UCM's top-level README. Once UCM is installed, ReRoPE is supported out of the box; run the following example script.

```bash
export VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1
export VLLM_USE_REROPE=true
export DATA_DIR=/home/data/kv_cache
export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct
export REROPE_WINDOW=32768
export TRAINING_LENGTH=32768

python examples/offline_inference_rerope.py
```
### Basic Usage

We need to override the model's max_position_embeddings according to the input length of the prompts, as shown below.

```python
llm_args = EngineArgs(
    model=model,
    kv_transfer_config=ktc,
    hf_overrides={
        "max_position_embeddings": 327680,
    },
    gpu_memory_utilization=0.9,
    max_num_batched_tokens=8192,
    block_size=16,
    enforce_eager=True,
    tensor_parallel_size=2,
)
```
## 📊 Supported Models

Qwen-based models are currently supported.

## 🎓 Cite

```
@misc{rerope2023,
    title={Rectified Rotary Position Embeddings},
    author={Jianlin Su},
    year={2023},
    howpublished={\url{https://github.com/bojone/rerope}},
}
```

docs/source/user-guide/sparse-attention/cacheblend.md

Lines changed: 4 additions & 4 deletions
@@ -1,9 +1,9 @@
-# CacheBlend: : Fast Large Language Model Serving for RAG with Cached Knowledge Fusion
+# CacheBlend: Fast Large Language Model Serving for RAG with Cached Knowledge Fusion
 <div align="center">

 ![blend_scheme.jpg](../../_static/images/blend_scheme.jpg)

-**🚀 Knowledge Cached Fusion Algorithm | 📄 EuroSys 2025 Paper **
+**🚀 Knowledge Cached Fusion Algorithm | 📄 EuroSys 2025 Paper**

 [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/ModelEngine-Group/unified-cache-management/blob/main/LICENSE)
 [![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)
@@ -31,7 +31,7 @@ CacheBlend reduces TTFT by 2.2 ~ 3.3× and increases throughput by 2.8 ~ 5× und
 1. **🔐 Chunk Hash Encoding**: Similar as prefix hash encoder, hash all blocks in each chunk from the same hash meta beginning.
 2. **⚡ Combine Prefix Cache and Chunk Cache**: Since chunk cache and native prefix cache share the same hash space, ucm first performs prefix cache lookup to fetch fully reused cache and then conduct chunk cache lookup to fetch the candidate cache for blending.
 3. **🎯 Delta-Rope PostProcess**: Rectify loaded chunk cache according to their position in the new request.
-3. **🔍 Integrate Cache Blend and First Token Generation**: Construct compute mask and attention meta according to HKVD tokens, cache miss tokens and suffix tokens, then compute their kv cache in a single model forward stage.
+3. **🔍 Integrate Cache Blend and First Token Generation**: Construct compute mask and attention meta according to the HKVD tokens, cache miss tokens and suffix tokens, then compute their kv cache in a single model forward stage.
 4. **🚀 Comprehensive Hook for LLM Forward Pipeline**: Based on ucm sparse module, blend module sparse the prefill tokens not only in attention stage but also in ffn, layer stage.

 ## 🚀 Quick Start
@@ -104,6 +104,6 @@ Llama-based models and Qwen-based models now are available

 <div align="center">

-**🌟 Star [UCM](https://github.com/ModelEngine-Group/unified-cache-management) repository if you find KvComp useful!**
+**🌟 Star [UCM](https://github.com/ModelEngine-Group/unified-cache-management) repository if you find CacheBlend useful!**

 </div>

docs/source/user-guide/sparse-attention/index.md

Lines changed: 1 addition & 0 deletions
@@ -41,4 +41,5 @@ esa
 gsa
 kvcomp
 kvstar
+cacheblend
 :::

examples/offline_inference_blend.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def main():
     # choose one data row in LongBenchV1 (wikimqa)
     assert os.path.isfile(
         path_to_dataset
-    ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
+    ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/home/data/Longbench/data/2wikimqa.jsonl`"
     with open(path_to_dataset, "r") as f:
         lines = f.readlines()
     dataset_row = json.loads(lines[0])
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
import contextlib
import json
import os
import sys
import time
from dataclasses import asdict

from transformers import AutoTokenizer

# setting for rerope
os.environ["VLLM_USE_REROPE"] = "true"

# Third Party
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs

from ucm.logger import init_logger

logger = init_logger(__name__)


def setup_environment_variables():
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["PYTHONHASHSEED"] = "123456"

    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
    os.environ["REROPE_WINDOW"] = "32768"
    os.environ["TRAINING_LENGTH"] = "32768"

    global data_dir
    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
    if not os.path.isdir(data_dir):
        data_dir = input(
            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
        )
        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
        if create.lower() == "y":
            os.makedirs(data_dir, exist_ok=True)
        else:
            print("Exiting. Directory not created.")
            sys.exit(1)

@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
    ktc = KVTransferConfig(
        kv_connector=name,
        kv_connector_module_path=module_path,
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_connectors": [
                {
                    "ucm_connector_name": "UcmNfsStore",
                    "ucm_connector_config": {
                        "storage_backends": data_dir,
                        "use_direct": False,
                    },
                }
            ],
        },
    )

    llm_args = EngineArgs(
        model=model,
        kv_transfer_config=ktc,
        hf_overrides={
            "max_position_embeddings": 327680,
        },
        gpu_memory_utilization=0.9,
        max_num_batched_tokens=8192,
        block_size=16,
        enforce_eager=True,
        tensor_parallel_size=2,
    )

    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        logger.info("LLM engine is exiting.")


def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    start = time.time()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
    print("-" * 50)

def main():
    module_path = "ucm.integration.vllm.ucm_connector"
    name = "UCMConnector"
    model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
    if not os.path.isdir(model):
        model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
        if not os.path.isdir(model):
            print("Exiting. Incorrect model_path")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
    setup_environment_variables()

    with build_llm_with_uc(module_path, name, model) as llm:

        data_all = []
        path_to_dataset = os.getenv(
            "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
        )
        if not os.path.isfile(path_to_dataset):
            path_to_dataset = input(
                "Enter path to one of the longbench dataset, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
            )
            if not os.path.isfile(path_to_dataset):
                print("Exiting. Incorrect dataset path")
                sys.exit(1)
        with open(path_to_dataset, "r", encoding="utf-8") as f:
            for line in f:
                data_all.append(json.loads(line))

        materials = []
        questions = []
        references = []
        batch_size = 30
        num_batch = 2
        for idx in range(num_batch):
            data = data_all[idx * batch_size : (idx + 1) * batch_size]

            # "【语料{i+1}" labels each passage ("Corpus {i+1}" in Chinese)
            materials.append(
                "\n\n".join(
                    [
                        f"【语料{i+1}\n{item.get('context', '')}"
                        for i, item in enumerate(data)
                    ]
                )
            )
            questions.append(
                "\n".join(
                    [
                        f"{i+1}. {item.get('input', '')}"
                        for i, item in enumerate(data[:15])
                    ]
                )
            )
            references.append(
                [
                    f"{i+1}. {item.get('answers', '')}"
                    for i, item in enumerate(data[:15])
                ]
            )

        # System prompt (zh): "You are an AI assistant. Answer the questions based on the following materials."
        system_prompt = "你是一个AI助手,请根据以下材料回答问题。"
        tokenized_inputs = []
        for material, question in zip(materials, questions):
            # User message (zh): wraps the material between "text content begin/end"
            # markers and asks the model to answer the listed questions directly.
            content = (
                "请根据以下文本内容回答后面的问题:\n\n"
                "【文本内容开始】\n"
                f"{material}\n"
                "【文本内容结束】\n\n"
                "请直接回答以下问题:\n"
                f"{question}"
            )

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ]
            inputs = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            tokenized_inputs.append(inputs)

        sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048)

        for req in range(num_batch):
            print_output(
                llm, tokenized_inputs[req], sampling_params, "request_" + str(req)
            )


if __name__ == "__main__":
    main()
