Commit 6899198: feature for triton rerope (#497)
1 parent ea76b36

8 files changed: +3805 −0 lines changed

Lines changed: 92 additions & 0 deletions
# Rectified Rotary Position Embeddings (ReRoPE)

ReRoPE extends the context length of an LLM effectively and without fine-tuning. This page covers the Triton implementation of ReRoPE and its integration into the vLLM inference framework.

**🚀 ReRoPE | 📄 Blog: [kexue.fm/archives/9708](https://kexue.fm/archives/9708) · [Rethinking Rotary Position Embedding](https://normxu.github.io/Rethinking-Rotary-Position-Embedding-3)**
[![License](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/ModelEngine-Group/unified-cache-management/blob/main/LICENSE)
[![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://python.org)
## 🌟 What is ReRoPE?

<img src="https://raw.githubusercontent.com/bojone/rerope/main/idea.png" width=750>

This approach combines direct extrapolation with position interpolation. A window size $w$ is established, where a position interval of $1$ is used within the window, and an interval of $\frac{1}{k}$ is applied outside. As $k \to \infty$, this simplifies to the form illustrated above. Under this scheme, the position encoding range never exceeds $w$ regardless of input length, potentially enabling support for arbitrarily long contexts.
The attention scores are computed as follows:

$$
\begin{align}
score_{ij}^{1} &= (q_iR_i)(k_jR_j)^T, && i-j<w \\
score_{ij}^{2} &= (q_iR_w)(k_j)^T, && i-j\ge w
\end{align}
$$
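For intuition, here is a minimal Python sketch of the position mapping described above (illustrative only; the function name and signature are not from this repository):

```python
def rerope_position(i: int, j: int, w: int, k: float = float("inf")) -> float:
    """Relative position assigned to the (i, j) attention pair.

    Interval 1 inside the window, 1/k outside; as k -> infinity this
    reduces to min(i - j, w), which yields the two score paths above.
    """
    rel = i - j
    if rel < w:
        return rel
    return w + (rel - w) / k

# With w = 4 and the default k = inf, distances beyond the window clip to w:
print([rerope_position(8, j, w=4) for j in range(8, 3, -1)])
# [0, 1, 2, 3, 4.0]
```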
ReRoPE extends context length effectively, but it requires computing attention twice (once locally within the window $w$ and once globally with compressed positions), which significantly reduces throughput. Despite this overhead, it remains valuable for training-free long-context inference, especially when combined with local attention windows to balance efficiency.

## 🧠 Triton ReRoPE Implementation

- Load Data

  Compared to the Triton RoPE implementation, data loading additionally requires passing `query2`, which uses the alternative rotary embedding position (the fixed window position $w$), and the unrotated `key2`.

- Construct ReRoPE Mask

  During attention computation, the choice between the two attention score paths depends on the relative distance between the query and key positions, which requires constructing a ReRoPE mask; see the sketch after this list.
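The following is a minimal PyTorch sketch of how the mask selects between the two score paths. It is an illustration, not the actual Triton kernel; the tensor names and the function signature are assumptions, and causal masking and softmax scaling are omitted for brevity.

```python
import torch

def rerope_scores(q_rot: torch.Tensor, q_rot_w: torch.Tensor,
                  k_rot: torch.Tensor, k_unrot: torch.Tensor,
                  window: int) -> torch.Tensor:
    """Combine the two ReRoPE score paths with a relative-distance mask.

    q_rot   : [L, d] queries rotated at their true positions (q_i R_i)
    q_rot_w : [L, d] queries rotated at the fixed window position (q_i R_w)
    k_rot   : [L, d] keys rotated at their true positions (k_j R_j)
    k_unrot : [L, d] unrotated keys (k_j)
    """
    scores1 = q_rot @ k_rot.T      # score^1_ij, used when i - j < window
    scores2 = q_rot_w @ k_unrot.T  # score^2_ij, used when i - j >= window
    pos = torch.arange(q_rot.shape[0], device=q_rot.device)
    rerope_mask = (pos[:, None] - pos[None, :]) < window  # True -> path 1
    return torch.where(rerope_mask, scores1, scores2)
```

The Triton kernel applies the same per-element selection, which is why both the standard and the alternative query/key tensors must be loaded.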
## 🏆 Results

![ReRoPE results](results.png)
## 🚀 Quick Start

### Installation

For installation instructions, please refer to UCM's top-level README. Once UCM is installed, ReRoPE is supported out of the box by running the following example script.
```bash
export VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1
export VLLM_USE_REROPE=true
export DATA_DIR=/home/data/kv_cache
export MODEL_PATH=/home/models/Qwen2.5-14B-Instruct
export REROPE_WINDOW=32768
export TRAINING_LENGTH=32768

python examples/offline_inference_rerope.py
```

### Basic Usage

The model's `max_position_embeddings` needs to be overridden according to the input length of the prompts, as shown below.

```python
llm_args = EngineArgs(
    model=model,
    kv_transfer_config=ktc,
    hf_overrides={
        # Extend the positional range beyond the model's training length.
        "max_position_embeddings": 327680,
    },
    gpu_memory_utilization=0.9,
    max_num_batched_tokens=8192,
    block_size=16,
    enforce_eager=True,
    tensor_parallel_size=2,
)
```

## 📊 Supported Models

Qwen-based models are currently supported.
## 🎓 Cite

```
@misc{rerope2023,
  title={Rectified Rotary Position Embeddings},
  author={Jianlin Su},
  year={2023},
  howpublished={\url{https://github.com/bojone/rerope}},
}
```
Lines changed: 193 additions & 0 deletions
import contextlib
import json
import os
import sys
import time
from dataclasses import asdict

from transformers import AutoTokenizer

# Setting for ReRoPE (placed before the vLLM imports so it takes effect).
os.environ["VLLM_USE_REROPE"] = "true"

# Third Party
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs

from ucm.logger import init_logger

logger = init_logger(__name__)


def setup_environment_variables():
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["PYTHONHASHSEED"] = "123456"

    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
    os.environ["REROPE_WINDOW"] = "32768"
    os.environ["TRAINING_LENGTH"] = "32768"

    global data_dir
    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
    if not os.path.isdir(data_dir):
        data_dir = input(
            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
        )
        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
        if create.lower() in ("", "y"):  # default to yes on empty input
            os.makedirs(data_dir, exist_ok=True)
        else:
            print("Exiting. Directory not created.")
            sys.exit(1)


@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
    ktc = KVTransferConfig(
        kv_connector=name,
        kv_connector_module_path=module_path,
        kv_role="kv_both",
        kv_connector_extra_config={
            "ucm_connectors": [
                {
                    "ucm_connector_name": "UcmNfsStore",
                    "ucm_connector_config": {
                        "storage_backends": data_dir,
                        "use_direct": False,
                    },
                }
            ],
        },
    )

    llm_args = EngineArgs(
        model=model,
        kv_transfer_config=ktc,
        hf_overrides={
            # Extend the positional range beyond the model's training length.
            "max_position_embeddings": 327680,
        },
        gpu_memory_utilization=0.9,
        max_num_batched_tokens=8192,
        block_size=16,
        enforce_eager=True,
        tensor_parallel_size=2,
    )

    llm = LLM(**asdict(llm_args))
    try:
        yield llm
    finally:
        logger.info("LLM engine is exiting.")


def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    start = time.time()
    outputs = llm.generate(prompt, sampling_params)
    print("-" * 50)
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
    print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
    print("-" * 50)


def main():
    module_path = "ucm.integration.vllm.ucm_connector"
    name = "UCMConnector"
    model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
    if not os.path.isdir(model):
        model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
        if not os.path.isdir(model):
            print("Exiting. Incorrect model path.")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(model)
    setup_environment_variables()

    with build_llm_with_uc(module_path, name, model) as llm:

        data_all = []
        path_to_dataset = os.getenv(
            "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
        )
        if not os.path.isfile(path_to_dataset):
            path_to_dataset = input(
                "Enter the path to one of the LongBench datasets, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
            )
            if not os.path.isfile(path_to_dataset):
                print("Exiting. Incorrect dataset path.")
                sys.exit(1)
        with open(path_to_dataset, "r", encoding="utf-8") as f:
            for line in f:
                data_all.append(json.loads(line))

        materials = []
        questions = []
        references = []
        batch_size = 30
        num_batch = 2
        for idx in range(num_batch):
            data = data_all[idx * batch_size : (idx + 1) * batch_size]

            # Label each context block; "语料" means "corpus material".
            materials.append(
                "\n\n".join(
                    [
                        f"【语料{i+1}】\n{item.get('context', '')}"
                        for i, item in enumerate(data)
                    ]
                )
            )
            questions.append(
                "\n".join(
                    [
                        f"{i+1}. {item.get('input', '')}"
                        for i, item in enumerate(data[:15])
                    ]
                )
            )
            references.append(
                [
                    f"{i+1}. {item.get('answers', '')}"
                    for i, item in enumerate(data[:15])
                ]
            )

        # System prompt: "You are an AI assistant; answer questions based on
        # the following materials."
        system_prompt = "你是一个AI助手,请根据以下材料回答问题。"
        tokenized_inputs = []
        for material, question in zip(materials, questions):
            # User content: "Answer the questions based on the text below",
            # with the material wrapped in begin/end markers.
            content = (
                "请根据以下文本内容回答后面的问题:\n\n"
                "【文本内容开始】\n"
                f"{material}\n"
                "【文本内容结束】\n\n"
                "请直接回答以下问题:\n"
                f"{question}"
            )

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ]
            inputs = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            tokenized_inputs.append(inputs)

        sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048)

        for req in range(num_batch):
            print_output(
                llm, tokenized_inputs[req], sampling_params, "request_" + str(req)
            )


if __name__ == "__main__":
    main()
