
Commit 64957e8

delete gqa
1 parent b0f3cd6 commit 64957e8

File tree

3 files changed (+68, -138 lines)

docs/index.rst

Lines changed: 21 additions & 5 deletions
@@ -8,17 +8,33 @@ Installation
 
 .. code-block:: bash
 
-    pip install vortex-torch
+    git clone https://github.com/Infini-AI-Lab/vortex_torch.git
+    cd vortex_torch
+    pip install -e .
 
 Quick Example
 -------------
-
 .. code-block:: python
 
-    import vortex_torch as vt
+
+
+.. code-block:: python
 
-    model = vt.Model(...)
-    out = model.forward(...)
+    llm = sgl.Engine(model_path="Qwen/Qwen3-0.6B",
+                     disable_cuda_graph=False,
+                     page_size=16,
+                     vortex_topk_val=30,
+                     disable_overlap_schedule=True,
+                     attention_backend="flashinfer",
+                     enable_vortex_sparsity=True,
+                     vortex_page_reserved_bos=1,
+                     vortex_page_reserved_eos=1,
+                     vortex_layers_skip=list(range(1)),
+                     vortex_module_path="path/to/custom_sparse_attention.py",
+                     vortex_module_name="custom_sparse_attention",
+                     vortex_max_seq_lens=8192,
+                     mem_fraction_static=0.6
+    )
 
 API Reference
 -------------
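
A minimal usage sketch to accompany the quick example above; it assumes SGLang's offline ``Engine.generate`` / ``shutdown`` API, and the prompt and sampling parameters are illustrative only.

.. code-block:: python

    # Assumes the ``llm`` engine constructed in the quick example above.
    prompts = ["Explain paged attention in one paragraph."]
    sampling_params = {"temperature": 0.0, "max_new_tokens": 128}

    outputs = llm.generate(prompts, sampling_params)
    for out in outputs:
        print(out["text"])

    llm.shutdown()  # release GPU memory held by the engine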

vortex_torch/flow/algorithms.py

Lines changed: 31 additions & 118 deletions
@@ -43,10 +43,14 @@ class BlockSparseAttention(vFlow):
 
     .. math::
 
-        o \in \mathbb{R}^{S_{\mathrm{sparse}} \times 1 \times 1},
+        o \in \mathbb{R}^{S \times 1 \times 1},
 
-    where :math:`S_{\mathrm{sparse}}` is a packed sparse page axis
-    as described in :class:`vFlow`.
+    Here :math:`S` is the leading page axis. Internally it is a packed
+    axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
+    concatenating the pages from all requests. As a user, you can simply
+    think of :math:`S` as "the number of pages for this request"; the
+    vFlow kernels and :class:`ContextBase` will take care of mapping
+    between per-request page counts and the packed layout automatically.
 
     Cache layout
     ------------
@@ -68,7 +72,7 @@ class BlockSparseAttention(vFlow):
     .. math::
 
        \text{cache["centroids"]} \sim
-       \mathbb{R}^{S_{\mathrm{pack}} \times 1 \times D},
+       \mathbb{R}^{S \times 1 \times D},
 
     - In :meth:`forward_cache` (batch-major view):
 
@@ -127,8 +131,7 @@ def forward_indexer(
            - ``cache["k"]`` and ``cache["v"]`` are page-packed key/value
              tensors,
            - ``cache["centroids"]`` is interpreted as
-             ``[S_pack, 1, D]`` (page-packed centroids), with
-             :math:`S_{\mathrm{pack}} = \sum_i S_i`.
+             ``[S, 1, D]`` (page-packed centroids).
 
        ctx : ContextBase
            Runtime context carrying page layout, top-k configuration
@@ -214,7 +217,7 @@ def create_cache(self, page_size: int, head_dim: int):
        - ``"centroids"`` with inner shape ``(1, head_dim)``, which
          becomes
 
-          - ``[S_pack, 1, head_dim]`` in :meth:`forward_indexer`,
+          - ``[S, 1, head_dim]`` in :meth:`forward_indexer`,
          - ``[B, 1, head_dim]`` in :meth:`forward_cache`.
        """
        return {
@@ -235,9 +238,15 @@ class GQABlockSparseAttention(vFlow):
    - Centroids cache ``cache["centroids"]`` has inner shape
      ``(1, head_dim)`` and is viewed as:
 
-      - ``[S_pack, 1, D]`` in :meth:`forward_indexer`,
+      - ``[S, 1, D]`` in :meth:`forward_indexer`,
      - ``[B, 1, D]`` in :meth:`forward_cache`.
-
+
+    Here :math:`S` is the leading page axis. Internally it is a packed
+    axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
+    concatenating the pages from all requests. As a user, you can simply
+    think of :math:`S` as "the number of pages for this request"; the
+    vFlow kernels and :class:`ContextBase` will take care of mapping
+    between per-request page counts and the packed layout automatically.
+
    For a design similar in spirit to grouped-query block sparsity, see
    the GQA sparse attention formulation in:
 
@@ -270,8 +279,8 @@ def forward_indexer(
        1. Apply :class:`GeMM` between queries and centroids:
 
           - ``q``: ``[B, H_q, D]``
-          - ``cache["centroids"]`` (indexer view): ``[S_pack, 1, D]``
-          - ``score``: ``[S_pack, H_q, 1]`` (logical ``[S, Ny, Nx]``)
+          - ``cache["centroids"]`` (indexer view): ``[S, 1, D]``
+          - ``score``: ``[S, H_q, 1]`` (logical ``[S, Ny, Nx]``)
 
        2. Apply in-place softmax over the leading (page) axis with a
           scaling factor ``scale``:
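
As a rough mental model of the two steps in the hunk above (outside the actual vFlow runtime), the centroid scoring can be sketched in plain PyTorch for a single request with dense, unpacked pages; the sizes, ``scale``, and top-k value below are illustrative only.

.. code-block:: python

    import torch

    B, H_q, D, num_pages, scale, topk = 1, 8, 128, 64, 0.09, 30

    q = torch.randn(B, H_q, D)                # queries, [B, H_q, D]
    centroids = torch.randn(num_pages, 1, D)  # per-page mean keys, [S, 1, D]

    # GeMM between queries and centroids -> score [S, H_q, 1]
    score = torch.einsum("bhd,sxd->shx", q, centroids)
    # softmax over the leading (page) axis with a scaling factor
    score = torch.softmax(score * scale, dim=0)
    # aggregate over query heads -> one scalar per page
    aggr = score.max(dim=1).values.squeeze(-1)        # [S]
    # top-k pages stand in for the sparse indices written to ``o``
    page_idx = aggr.topk(k=min(topk, num_pages)).indices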
@@ -352,12 +361,19 @@ class GQAQuestSparseAttention(vFlow):
    - ``cache["max"]`` and ``cache["min"]``: ``(1, head_dim)``
      → viewed as
 
-      - ``[S_pack, 1, D]`` in :meth:`forward_indexer`,
+      - ``[S, 1, D]`` in :meth:`forward_indexer`,
      - ``[B, 1, D]`` in :meth:`forward_cache`.
 
    - ``cache["k"]``: standard key cache with inner shape
      ``(page_size, head_dim)``.
 
+    Here :math:`S` is the leading page axis. Internally it is a packed
+    axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
+    concatenating the pages from all requests. As a user, you can simply
+    think of :math:`S` as "the number of pages for this request"; the
+    vFlow kernels and :class:`ContextBase` will take care of mapping
+    between per-request page counts and the packed layout automatically.
+
    Routing intuition
    -----------------
    For each query and page envelope:
@@ -401,15 +417,15 @@ def forward_indexer(
        Let:
 
        - ``q``: ``[B, H_q, D]``
-       - ``cache["max"]``: ``[S_pack, 1, D]``
-       - ``cache["min"]``: ``[S_pack, 1, D]``
+       - ``cache["max"]``: ``[S, 1, D]``
+       - ``cache["min"]``: ``[S, 1, D]``
 
        Steps:
 
        1. ``s_max = q * max_envelope``
        2. ``s_min = q * min_envelope``
        3. ``s = max(s_max, s_min)`` (elementwise)
-       4. ``score = sum(s, dim=D)`` → ``[S_pack, H_q, 1]``
+       4. ``score = sum(s, dim=D)`` → ``[S, H_q, 1]``
        5. ``aggr_score = max(score, dim=H_q)`` → per-page scalar
        6. :class:`topK` converts ``aggr_score`` into sparse page
           indices ``o`` of shape ``[S_sparse, 1, 1]``.
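
The Quest-style steps listed above can likewise be sketched in plain PyTorch (again outside the vFlow runtime) for one request with dense, unpacked pages; all sizes and the top-k value are illustrative.

.. code-block:: python

    import torch

    B, H_q, D, num_pages, topk = 1, 8, 128, 64, 30

    q = torch.randn(B, H_q, D)               # queries
    page_max = torch.randn(num_pages, 1, D)  # per-page, per-dim key maxima
    page_min = torch.randn(num_pages, 1, D)  # per-page, per-dim key minima

    # Steps 1-3: products with both envelopes, keep the larger bound.
    # With B == 1 we drop the batch axis so broadcasting gives [S, H_q, D].
    s_max = page_max * q[0]
    s_min = page_min * q[0]
    s = torch.maximum(s_max, s_min)

    # Step 4: sum over the feature dim -> per-page, per-head score [S, H_q]
    score = s.sum(dim=-1)
    # Step 5: aggregate over query heads -> one scalar per page [S]
    aggr_score = score.max(dim=-1).values
    # Step 6: top-k page indices stand in for the sparse output ``o``
    page_idx = aggr_score.topk(k=min(topk, num_pages)).indices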
@@ -468,106 +484,3 @@ def create_cache(self, page_size: int, head_dim: int):
             "max": (1, head_dim),
             "min": (1, head_dim),
         }
-
-
-
-# Generated by GPT5.1
-@register("gqa_dynamic_hybrid_sparse_attention")
-class GQADynamicHybridSparseAttention(vFlow):
-    """
-    Dynamic hybrid sparse attention:
-    - Maintains mean, max, and min statistics per block.
-    - Uses a block-sparse (centroid-based) score path.
-    - Uses a Quest-style (max/min) score path.
-    - Combines them via element-wise max as a dynamic gating signal.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-        # ----- indexer ops -----
-        # Block-style scoring
-        self.gemm = GeMM()
-        self.softmax = Softmax(dim=0, scale=0.09)
-        self.max_over_heads = Max(dim=2)  # same as GQABlockSparseAttention
-
-        # Quest-style scoring
-        self.mul_max = Multiply()
-        self.mul_min = Multiply()
-        self.max_elementwise = Maximum()
-        self.sum_over_dim = Sum(dim=2)  # same as GQAQuestSparseAttention
-        self.max_over_queries = Max(dim=1)
-
-        # Combine block + quest scores
-        self.merge_scores = Maximum()  # element-wise max between the two scores
-
-        # Final selection
-        self.output_func = topK()
-
-        # ----- cache ops -----
-        self.reduction_mean = CMean(dim=1)  # centroids
-        self.reduction_max = CMax(dim=1)    # per-dim max
-        self.reduction_min = CMin(dim=1)    # per-dim min
-
-    def forward_indexer(self, q, o, cache: Dict[str, torch.Tensor], ctx: ContextBase):
-        """
-        q: query tensor (GQA-packed)
-        o: indexer output tensor (indices / scores buffer for topK)
-        cache: contains "centroids", "max", "min"
-        """
-
-        # ---- 1. Block-style centroid scoring ----
-        # score_block: [*, *, num_blocks] (same shape as in GQABlockSparseAttention)
-        score_block = self.gemm(q, cache["centroids"], ctx=ctx)
-        self.softmax(score_block, ctx=ctx)
-        # Aggregate over heads → [*, num_blocks]
-        aggr_block = self.max_over_heads(score_block, ctx=ctx)
-
-        # ---- 2. Quest-style max/min gating ----
-        # Element-wise products with cached max/min stats
-        s_max = self.mul_max(q, cache["max"], ctx=ctx)
-        s_min = self.mul_min(q, cache["min"], ctx=ctx)
-
-        # Take the element-wise max between the two projections
-        s = self.max_elementwise(s_max, s_min, ctx=ctx)
-
-        # Sum over feature dimension → [num_queries, num_heads, num_blocks]
-        score_quest = self.sum_over_dim(s, ctx=ctx)
-
-        # Aggregate over queries → [num_heads, num_blocks] or [*, num_blocks]
-        aggr_quest = self.max_over_queries(score_quest, ctx=ctx)
-
-        # ---- 3. Dynamic merge ----
-        # For each block, take whichever score (block vs quest) is stronger.
-        # This yields a per-block dynamic gate.
-        combined_score = self.merge_scores(aggr_block, aggr_quest, ctx=ctx)
-
-        # ---- 4. Top-K block selection ----
-        self.output_func(combined_score, o, ctx=ctx)
-
-    def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
-        """
-        cache["k"]: full key buffer for the page
-        loc: index of the page / block being updated
-        """
-
-        # Update mean (centroids)
-        self.reduction_mean(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
-
-        # Update per-dimension maxima and minima
-        self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
-        self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)
-
-    def create_cache(self, page_size: int, head_dim: int):
-        """
-        For each block/page we maintain:
-        - centroids: mean key per dimension
-        - max: max key per dimension
-        - min: min key per dimension
-        """
-        return {
-            "centroids": (1, head_dim),
-            "max": (1, head_dim),
-            "min": (1, head_dim),
-        }
-

vortex_torch/flow/flow.py

Lines changed: 16 additions & 15 deletions
@@ -81,18 +81,20 @@ class vFlow(ABC):
       .. math::
 
          \text{cache[key]} \sim
-         \mathbb{R}^{S_{\text{pack}} \times r \times c},
+         \mathbb{R}^{S \times r \times c},
 
-      where
+
 
-      .. math::
-
-         S_{\text{pack}} = \sum_{i=0}^{B-1} S_i
-
-      is the total number of pages packed across all requests, and
       :math:`(r, c)` is the per-key inner shape declared via
       :meth:`create_cache` or implicitly for ``"k"``/``"v"``.
 
+      Here :math:`S` is the leading page axis. Internally it is a packed
+      axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
+      concatenating the pages from all requests. As a user, you can simply
+      think of :math:`S` as "the number of pages for this request"; the
+      vFlow kernels and :class:`ContextBase` will take care of mapping
+      between per-request page counts and the packed layout automatically.
+
    2. **Cache-update view (batch-major)** — used in :meth:`forward_cache`:
 
       .. math::
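
A small, self-contained illustration of the packed-page view described above (not the runtime's own code); the per-request page counts and inner shape are made up for the example.

.. code-block:: python

    import torch

    r, c = 1, 128                  # per-key inner shape from create_cache()
    pages_per_request = [3, 5, 2]  # S_0, S_1, S_2 for a batch of B = 3

    # Per-request cache slabs, each of shape [S_i, r, c]
    per_request = [torch.randn(s, r, c) for s in pages_per_request]

    # Packed indexer view: concatenate along the leading page axis,
    # so S (a.k.a. S_pack) = sum_i S_i
    packed = torch.cat(per_request, dim=0)
    assert packed.shape == (sum(pages_per_request), r, c)

    # Cumulative offsets map a packed page index back to (request, local page)
    offsets = torch.tensor([0] + pages_per_request).cumsum(0)  # [0, 3, 8, 10]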
@@ -140,7 +142,7 @@ class vFlow(ABC):
          {\text{page_size} \cdot \text{head_dim}}.
 
    This ignores the leading dimension (whether :math:`B` or
-   :math:`S_{\text{pack}}`) and compares only inner shapes to the
+   :math:`S`) and compares only inner shapes to the
    baseline ``(page_size, head_dim)``.
 
    Subclass responsibilities
@@ -149,7 +151,7 @@ class vFlow(ABC):
 
    - :meth:`forward_indexer(q, o, cache, ctx)`:
      compute sparse page indices (or routing scores) from queries,
-     using cache in the :math:`S_{\text{pack}}` view.
+     using cache in the :math:`S` view.
 
    - :meth:`forward_cache(cache, loc, ctx)`:
      update cache tensors using the :math:`B`-major view and positional
@@ -203,9 +205,8 @@ def forward_indexer(
        .. math::
 
           \text{cache[key]}
-          \sim \mathbb{R}^{S_{\text{pack}} \times r \times c},
+          \sim \mathbb{R}^{S \times r \times c},
 
-       where :math:`S_{\text{pack}} = \sum_i S_i` and
        :math:`(r, c)` are the per-key inner dimensions obtained from
        :meth:`get_cache_meta_info`.
 
@@ -219,7 +220,7 @@ def forward_indexer(
        --------
        Implementations should:
 
-       - interpret ``cache`` in the :math:`S_{\text{pack}}` view,
+       - interpret ``cache`` in the :math:`S` view,
        - use ``q`` and relevant cache tensors to score/select pages,
        - respect per-request bounds derived from ``ctx``,
        - write the resulting sparse indices (or routing representation)
@@ -291,7 +292,7 @@ def create_cache(
 
        This method **does not allocate** tensors. It only declares the
        per-key inner dimensions :math:`(r, c)`; the runtime will attach
-       the appropriate leading axis (:math:`B` or :math:`S_{\text{pack}}`)
+       the appropriate leading axis (:math:`B` or :math:`S`)
        depending on whether the cache is used in :meth:`forward_cache`
        or :meth:`forward_indexer`.
 
@@ -357,7 +358,7 @@ def get_cache_meta_info(
        Dict[str, Tuple[int, int]]
            Mapping from cache tensor names to inner shapes ``(r, c)``.
            The runtime will later prepend either a batch axis ``B`` or a
-           packed-page axis ``S_pack`` when materializing the tensors.
+           packed-page axis ``S`` when materializing the tensors.
 
        Raises
        ------
@@ -392,7 +393,7 @@ def get_token_ratio(self, page_size: int, head_dim: int) -> float:
           \frac{r_{\text{key}} \cdot c_{\text{key}}}
           {\text{page_size} \cdot \text{head_dim}}.
 
-       The leading dimension (:math:`B` or :math:`S_{\text{pack}}`) is
+       The leading dimension (:math:`B` or :math:`S`) is
        not included in this ratio on purpose; it is a per-page
        normalization.
 
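
A quick stand-alone check of the per-page token ratio above; ``page_size`` and ``head_dim`` are example values, not taken from this repository.

.. code-block:: python

    page_size, head_dim = 16, 128

    # A cache entry with inner shape (1, head_dim), such as the "centroids"
    # declared by BlockSparseAttention.create_cache, costs 1/page_size of a
    # full (page_size, head_dim) key page:
    r, c = 1, head_dim
    ratio = (r * c) / (page_size * head_dim)
    assert ratio == 1 / page_size  # 0.0625 for page_size == 16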
