CacheBlend : Insert table of contents and modify comments
ReRoPE: Revise the formatting, fix the export error in the ReRoPE documentation, and add web links
---------
Co-authored-by: wuhuxiao <[email protected]>
docs/source/user-guide/rerope/rerope.md
Lines changed: 24 additions & 9 deletions
@@ -1,26 +1,34 @@
-# Rectified Rotary Position Embeddings (ReRoPE)
+# Rectified Rotary Position Embeddings

-Using ReRoPE, we can more effectively extend the context length of LLM without the need for fine-tuning. This is about the Triton implementation of ReRoPE and its integration into the vLLM inference framework.
+Using Rectified Rotary Position Embeddings (ReRoPE), we can more effectively extend the context length of an LLM without the need for fine-tuning. This page covers the Triton implementation of ReRoPE and its integration into the vLLM inference framework.
+
+<div align="center">

**🚀 ReRoPE | 📄 blog: [https://kexue.fm/archives/9708](https://kexue.fm/archives/9708) | [https://normxu.github.io/Rethinking-Rotary-Position-Embedding-3](https://normxu.github.io/Rethinking-Rotary-Position-Embedding-3)**
This approach combines direct extrapolation with position interpolation. A window size $w$ is established, where a position interval of $1$ is used within the window, and an interval of $\frac{1}{k}$ is applied outside. As $k \to \infty$, this simplifies to the form illustrated above. Under this scheme, the position encoding range never exceeds $w$ regardless of input length, potentially enabling support for arbitrarily long contexts.
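
To make the mapping concrete, here is a minimal NumPy sketch of the rectified relative-position matrix. It is an illustration only, not the UCM/vLLM code; the function name `rerope_relative_positions` and its parameters are hypothetical.

```python
# Illustrative sketch only (not the UCM/vLLM implementation).
# Relative positions advance by 1 inside the window w and by 1/k outside it;
# letting k -> infinity clamps every out-of-window distance to w.
import numpy as np

def rerope_relative_positions(seq_len: int, w: int, k: float = float("inf")) -> np.ndarray:
    """Return the (seq_len, seq_len) matrix of rectified relative positions i - j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    rel = (i - j).astype(np.float64)
    if np.isinf(k):
        return np.where(rel < w, rel, float(w))           # limiting case: clamp at w
    return np.where(rel < w, rel, w + (rel - w) / k)      # finite k: interval 1/k outside

print(rerope_relative_positions(8, w=4)[-1])  # [4. 4. 4. 4. 3. 2. 1. 0.]
```

Only the lower-triangular part matters for a causal decoder; the negative entries above the diagonal are removed by the causal mask anyway.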
The attention score calculation formulas are as follows,

$$
-\begin{align}
+\begin{aligned}
score_{ij}^{1} &= (q_iR_i)(k_jR_j)^T, && i-j<w \\
score_{ij}^{2} &= (q_iR_w)(k_j)^T, && i-j\ge w
-\end{align}
+\end{aligned}
$$
ReRoPE extends the context length effectively, but it requires two attention passes (a local pass within the window $w$ and a global pass with compressed, clamped positions), which significantly reduces throughput. Despite this overhead, it remains valuable for training-free long contexts, especially when combined with local attention windows to balance efficiency.
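
As a readability aid, the following NumPy sketch spells out the two score branches and the window mask that selects between them. It is not the fused Triton kernel used in UCM; the helpers `rotate` and `rerope_scores` are invented for this example and assume a standard interleaved-pair RoPE.

```python
# Naive two-pass ReRoPE scores -- illustrative only, not the fused Triton kernel.
import numpy as np

def rotate(x: np.ndarray, pos: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Apply the RoPE rotation R_pos[t] to each row x[t]; x is (seq, dim), pos is (seq,)."""
    dim = x.shape[-1]
    inv_freq = 1.0 / base ** (np.arange(0, dim, 2) / dim)
    angles = pos[:, None] * inv_freq[None, :]              # (seq, dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x, dtype=np.float64)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rerope_scores(q: np.ndarray, k: np.ndarray, w: int) -> np.ndarray:
    """Combine the two branches: ordinary RoPE inside the window, clamped position outside."""
    seq = q.shape[0]
    pos = np.arange(seq)
    s1 = rotate(q, pos) @ rotate(k, pos).T                 # score^1: (q_i R_i)(k_j R_j)^T
    s2 = rotate(q, np.full(seq, float(w))) @ k.T           # score^2: (q_i R_w) k_j^T
    i, j = np.meshgrid(pos, pos, indexing="ij")
    return np.where(i - j < w, s1, s2)                     # window mask picks the branch
```

With this formulation every query needs both passes, which is where the roughly doubled attention cost described above comes from.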
The experiment is based on a hybrid Transformer-GAU (Gated Attention Unit) model with 100M parameters. $\log n$ indicates that we add the scale factor $\log n$ at the pretraining stage; $\log n^{*}$ denotes that we apply the scale factor to the attention matrix only for text exceeding the maximum sequence length, without any pretraining; $w256$ denotes the ReRoPE window $w=256$.
For installation instructions, please refer to UCM's top-level README. Once UCM is installed, ReRoPE is supported out of the box by running the following example Python scripts.
@@ -31,7 +31,7 @@ CacheBlend reduces TTFT by 2.2 ~ 3.3× and increases throughput by 2.8 ~ 5× und
1. **🔐 Chunk Hash Encoding**: Similar to the prefix hash encoder, hash all blocks in each chunk starting from the same initial hash meta.
2. **⚡ Combine Prefix Cache and Chunk Cache**: Since the chunk cache and the native prefix cache share the same hash space, UCM first performs a prefix cache lookup to fetch the fully reused cache and then conducts a chunk cache lookup to fetch the candidate cache for blending (a minimal lookup sketch follows this list).
3. **🎯 Delta-Rope PostProcess**: Rectify the loaded chunk cache according to its position in the new request.
-3. **🔍 Integrate Cache Blend and First Token Generation**: Construct compute mask and attention meta according to HKVD tokens, cache miss tokens and suffix tokens, then compute their kv cache in a single model forward stage.
+3. **🔍 Integrate Cache Blend and First Token Generation**: Construct the compute mask and attention metadata according to the HKVD tokens, cache-miss tokens, and suffix tokens, then compute their KV cache in a single model forward pass.
4. **🚀 Comprehensive Hook for LLM Forward Pipeline**: Based on the UCM sparse module, the blend module sparsifies the prefill tokens not only in the attention stage but also in the FFN and layer stages.
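
To illustrate steps 1-2, here is a minimal sketch of chunk hash encoding and the two-stage lookup. It is not the UCM API: the SHA-256 chaining, the helper names (`chunk_block_hashes`, `two_stage_lookup`), and the set-based caches are assumptions made for this example.

```python
# Illustrative sketch only -- hypothetical helpers, not the UCM API.
import hashlib

def chunk_block_hashes(blocks: list[tuple[int, ...]], meta_hash: str) -> list[str]:
    """Hash every token block of one chunk, chaining from the shared meta hash,
    so every chunk's hashes start from the same root (SHA-256 chaining assumed)."""
    parent, hashes = meta_hash, []
    for block in blocks:
        parent = hashlib.sha256(f"{parent}:{block}".encode()).hexdigest()
        hashes.append(parent)
    return hashes

def two_stage_lookup(request_hashes: list[str], prefix_cache: set[str], chunk_cache: set[str]):
    """Prefix-cache lookup first (longest fully reused prefix), then chunk-cache
    lookup over the remaining blocks for candidate caches to blend."""
    n_prefix = 0
    while n_prefix < len(request_hashes) and request_hashes[n_prefix] in prefix_cache:
        n_prefix += 1
    blend_candidates = [h for h in request_hashes[n_prefix:] if h in chunk_cache]
    return request_hashes[:n_prefix], blend_candidates
```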