@@ -43,10 +43,14 @@ class BlockSparseAttention(vFlow):
4343
4444 .. math::
4545
46- o \in \mathbb{R}^{S_{\mathrm{sparse} } \times 1 \times 1},
46+ o \in \mathbb{R}^{S \times 1 \times 1},
4747
48- where :math:`S_{\mathrm{sparse}}` is a packed sparse page axis
49- as described in :class:`vFlow`.
48+ Here :math:`S` is the leading page axis. Internally it is a packed
49+ axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
50+ concatenating the pages from all requests. As a user, you can simply
51+ think of :math:`S` as "the number of pages for this request"; the
52+ vFlow kernels and :class:`ContextBase` will take care of mapping
53+ between per-request page counts and the packed layout automatically.
5054
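As a rough illustration of how per-request page counts relate to the packed axis (a minimal sketch only; the counts below are made up, and the real bookkeeping is done by :class:`ContextBase`, not shown here):

```python
import torch

# Hypothetical per-request page counts for a batch of three requests.
pages_per_request = torch.tensor([3, 2, 4])

# Packed page axis: all pages concatenated, S_pack = sum of per-request counts.
S_pack = int(pages_per_request.sum())                                  # 9

# Offset of each request's first page inside the packed axis.
offsets = torch.cumsum(pages_per_request, dim=0) - pages_per_request   # tensor([0, 3, 5])

# Request i owns packed page indices offsets[i] .. offsets[i] + pages_per_request[i] - 1.
for i, (off, n) in enumerate(zip(offsets.tolist(), pages_per_request.tolist())):
    print(f"request {i}: packed pages {list(range(off, off + n))}")
```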
5155 Cache layout
5256 ------------
@@ -68,7 +72,7 @@ class BlockSparseAttention(vFlow):
6872 .. math::
6973
7074 \text{cache["centroids"]} \sim
71- \mathbb{R}^{S_{\mathrm{pack} } \times 1 \times D},
75+ \mathbb{R}^{S \times 1 \times D},
7276
7377 - In :meth:`forward_cache` (batch-major view):
7478
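The hunk is cut before the shape, but from the other docstrings in this diff the batch-major view is ``[B, 1, D]``: one centroid row per request, picked out of the packed buffer. A minimal sketch of that reading (the ``loc`` tensor and all sizes below are assumptions, not part of the diff):

```python
import torch

S_pack, B, D = 9, 3, 64
centroids_packed = torch.randn(S_pack, 1, D)   # indexer view: [S, 1, D]

# Hypothetical packed index of the page currently being written by each request.
loc = torch.tensor([2, 4, 8])

centroids_batch = centroids_packed[loc]        # batch-major view: [B, 1, D]
print(centroids_batch.shape)                   # torch.Size([3, 1, 64])
```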
@@ -127,8 +131,7 @@ def forward_indexer(
127131 - ``cache["k"]`` and ``cache["v"]`` are page-packed key/value
128132 tensors,
129133 - ``cache["centroids"]`` is interpreted as
130- ``[S_pack, 1, D]`` (page-packed centroids), with
131- :math:`S_{\mathrm{pack}} = \sum_i S_i`.
134+ ``[S, 1, D]`` (page-packed centroids).
132135
133136 ctx : ContextBase
134137 Runtime context carrying page layout, top-k configuration
@@ -214,7 +217,7 @@ def create_cache(self, page_size: int, head_dim: int):
214217 - ``"centroids"`` with inner shape ``(1, head_dim)``, which
215218 becomes
216219
217- - ``[S_pack , 1, head_dim]`` in :meth:`forward_indexer`,
220+ - ``[S, 1, head_dim]`` in :meth:`forward_indexer`,
218221 - ``[B, 1, head_dim]`` in :meth:`forward_cache`.
219222 """
220223 return {
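The returned mapping is only a per-page shape spec. As a sketch, a runtime could expand it into page-packed buffers along the leading page axis roughly like this; the ``allocate_cache`` helper and all sizes are hypothetical, and the ``"k"``/``"v"`` entries are assumed from the :meth:`forward_indexer` description above:

```python
import torch

def allocate_cache(spec, num_pages, dtype=torch.float16):
    """Expand a {name: inner_shape} spec into buffers of shape [num_pages, *inner_shape]."""
    return {name: torch.zeros(num_pages, *shape, dtype=dtype) for name, shape in spec.items()}

spec = {"k": (16, 64), "v": (16, 64), "centroids": (1, 64)}   # page_size=16, head_dim=64
buffers = allocate_cache(spec, num_pages=9)                   # 9 packed pages
print(buffers["centroids"].shape)                             # [S_pack, 1, head_dim] = [9, 1, 64]
```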
@@ -235,9 +238,15 @@ class GQABlockSparseAttention(vFlow):
235238 - Centroids cache ``cache["centroids"]`` has inner shape
236239 ``(1, head_dim)`` and is viewed as:
237240
238- - ``[S_pack , 1, D]`` in :meth:`forward_indexer`,
241+ - ``[S, 1, D]`` in :meth:`forward_indexer`,
239242 - ``[B, 1, D]`` in :meth:`forward_cache`.
240-
243+ Here :math:`S` is the leading page axis. Internally it is a packed
244+ axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
245+ concatenating the pages from all requests. As a user, you can simply
246+ think of :math:`S` as "the number of pages for this request"; the
247+ vFlow kernels and :class:`ContextBase` will take care of mapping
248+ between per-request page counts and the packed layout automatically.
249+
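For intuition, the centroid entry for a page is the kind of per-page statistic :meth:`forward_cache` maintains: a mean key per dimension, as the hybrid class at the bottom of this diff computes with ``CMean``. A plain PyTorch sketch of such an update (buffer layout and ``loc`` indexing are assumptions consistent with the shapes above):

```python
import torch

S_pack, page_size, D = 9, 16, 64
k_cache = torch.randn(S_pack, page_size, D)      # page-packed keys, cache["k"]
centroids = torch.zeros(S_pack, 1, D)            # cache["centroids"]

loc = 3                                          # page being updated
centroids[loc] = k_cache[loc].mean(dim=0, keepdim=True)   # [1, D] mean key for this page
```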
241250 For a design similar in spirit to grouped-query block sparsity, see
242251 the GQA sparse attention formulation in:
243252
@@ -270,8 +279,8 @@ def forward_indexer(
270279 1. Apply :class:`GeMM` between queries and centroids:
271280
272281 - ``q``: ``[B, H_q, D]``
273- - ``cache["centroids"]`` (indexer view): ``[S_pack , 1, D]``
274- - ``score``: ``[S_pack , H_q, 1]`` (logical ``[S, Ny, Nx]``)
282+ - ``cache["centroids"]`` (indexer view): ``[S, 1, D]``
283+ - ``score``: ``[S, H_q, 1]`` (logical ``[S, Ny, Nx]``)
275284
276285 2. Apply in-place softmax over the leading (page) axis with a
277286 scaling factor ``scale``:
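Steps 1-2 above (together with the head aggregation and top-k selection that the other indexers in this file describe) can be approximated in plain PyTorch as follows. The sizes, the 0.09 scale (borrowed from the ``Softmax`` op in the hybrid class below), and the top-k budget are illustrative only; the real ops are vFlow primitives driven by ``ctx``:

```python
import torch

H_q, D, S, k = 8, 64, 9, 4          # illustrative sizes and top-k budget
scale = 0.09                        # assumed scaling factor

q = torch.randn(H_q, D)             # queries for one decode step (B assumed to be 1)
centroids = torch.randn(S, 1, D)    # indexer view of cache["centroids"]

# Step 1: GeMM between queries and per-page centroids -> [S, H_q, 1]
score = torch.einsum("hd,sod->sho", q, centroids)

# Step 2: softmax over the leading (page) axis with scaling
score = torch.softmax(score * scale, dim=0)

# Aggregate over query heads and pick the top-k pages (cf. the Quest indexer below).
aggr_score = score.max(dim=1).values.squeeze(-1)   # [S]
topk_pages = aggr_score.topk(k).indices            # sparse page indices
print(topk_pages)
```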
@@ -352,12 +361,19 @@ class GQAQuestSparseAttention(vFlow):
352361 - ``cache["max"]`` and ``cache["min"]``: ``(1, head_dim)``
353362 → viewed as
354363
355- - ``[S_pack , 1, D]`` in :meth:`forward_indexer`,
364+ - ``[S, 1, D]`` in :meth:`forward_indexer`,
356365 - ``[B, 1, D]`` in :meth:`forward_cache`.
357366
358367 - ``cache["k"]``: standard key cache with inner shape
359368 ``(page_size, head_dim)``.
360369
370+ Here :math:`S` is the leading page axis. Internally it is a packed
371+ axis (often denoted :math:`S_{\mathrm{pack}}`), obtained by
372+ concatenating the pages from all requests. As a user, you can simply
373+ think of :math:`S` as "the number of pages for this request"; the
374+ vFlow kernels and :class:`ContextBase` will take care of mapping
375+ between per-request page counts and the packed layout automatically.
376+
361377 Routing intuition
362378 -----------------
363379 For each query and page envelope:
@@ -401,15 +417,15 @@ def forward_indexer(
401417 Let:
402418
403419 - ``q``: ``[B, H_q, D]``
404- - ``cache["max"]``: ``[S_pack , 1, D]``
405- - ``cache["min"]``: ``[S_pack , 1, D]``
420+ - ``cache["max"]``: ``[S, 1, D]``
421+ - ``cache["min"]``: ``[S, 1, D]``
406422
407423 Steps:
408424
409425 1. ``s_max = q * max_envelope``
410426 2. ``s_min = q * min_envelope``
411427 3. ``s = max(s_max, s_min)`` (elementwise)
412- 4. ``score = sum(s, dim=D)`` → ``[S_pack , H_q, 1]``
428+ 4. ``score = sum(s, dim=D)`` → ``[S, H_q, 1]``
413429 5. ``aggr_score = max(score, dim=H_q)`` → per-page scalar
414430 6. :class:`topK` converts ``aggr_score`` into sparse page
415431 indices ``o`` of shape ``[S_sparse, 1, 1]``.
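A rough PyTorch transcription of steps 1-6 (all shapes and the top-k budget ``k`` are made up; the real ops are vFlow primitives that run over the packed page axis via ``ctx``):

```python
import torch

H_q, D, S, k = 8, 64, 9, 4            # illustrative sizes; k is an assumed top-k budget

q = torch.randn(H_q, D)               # decode queries (B assumed to be 1)
max_env = torch.randn(S, 1, D)        # cache["max"], indexer view [S, 1, D]
min_env = torch.randn(S, 1, D)        # cache["min"], indexer view [S, 1, D]

s_max = q * max_env                                 # step 1: broadcast to [S, H_q, D]
s_min = q * min_env                                 # step 2
s = torch.maximum(s_max, s_min)                     # step 3: elementwise max
score = s.sum(dim=-1, keepdim=True)                 # step 4: [S, H_q, 1]
aggr_score = score.max(dim=1).values.squeeze(-1)    # step 5: per-page scalar, [S]
o = aggr_score.topk(k).indices                      # step 6: sparse page indices
print(o)                                            # would be viewed as [S_sparse, 1, 1]
```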
@@ -468,106 +484,3 @@ def create_cache(self, page_size: int, head_dim: int):
468484 "max" : (1 , head_dim ),
469485 "min" : (1 , head_dim ),
470486 }
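On the cache side, the per-page envelopes backing this spec can be refreshed from a page's keys roughly as follows (a sketch only; buffer layout and ``loc`` indexing are assumptions consistent with the shapes above):

```python
import torch

S_pack, page_size, head_dim = 9, 16, 64

k_cache = torch.randn(S_pack, page_size, head_dim)   # cache["k"], page-packed
max_cache = torch.zeros(S_pack, 1, head_dim)         # cache["max"]
min_cache = torch.zeros(S_pack, 1, head_dim)         # cache["min"]

loc = 3                                               # page being updated
page_keys = k_cache[loc]                              # [page_size, head_dim]

# Per-dimension envelope over the keys held in this page.
max_cache[loc] = page_keys.max(dim=0, keepdim=True).values
min_cache[loc] = page_keys.min(dim=0, keepdim=True).values
```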
471-
472-
473-
474- # Generated by GPT5.1
475- @register("gqa_dynamic_hybrid_sparse_attention")
476- class GQADynamicHybridSparseAttention(vFlow):
477-     """
478-     Dynamic hybrid sparse attention:
479-     - Maintains mean, max, and min statistics per block.
480-     - Uses a block-sparse (centroid-based) score path.
481-     - Uses a Quest-style (max/min) score path.
482-     - Combines them via element-wise max as a dynamic gating signal.
483-     """
484-
485-     def __init__(self):
486-         super().__init__()
487-
488-         # ----- indexer ops -----
489-         # Block-style scoring
490-         self.gemm = GeMM()
491-         self.softmax = Softmax(dim=0, scale=0.09)
492-         self.max_over_heads = Max(dim=2)  # same as GQABlockSparseAttention
493-
494-         # Quest-style scoring
495-         self.mul_max = Multiply()
496-         self.mul_min = Multiply()
497-         self.max_elementwise = Maximum()
498-         self.sum_over_dim = Sum(dim=2)  # same as GQAQuestSparseAttention
499-         self.max_over_queries = Max(dim=1)
500-
501-         # Combine block + quest scores
502-         self.merge_scores = Maximum()  # element-wise max between the two scores
503-
504-         # Final selection
505-         self.output_func = topK()
506-
507-         # ----- cache ops -----
508-         self.reduction_mean = CMean(dim=1)  # centroids
509-         self.reduction_max = CMax(dim=1)  # per-dim max
510-         self.reduction_min = CMin(dim=1)  # per-dim min
511-
512-     def forward_indexer(self, q, o, cache: Dict[str, torch.Tensor], ctx: ContextBase):
513-         """
514-         q: query tensor (GQA-packed)
515-         o: indexer output tensor (indices / scores buffer for topK)
516-         cache: contains "centroids", "max", "min"
517-         """
518-
519-         # ---- 1. Block-style centroid scoring ----
520-         # score_block: [*, *, num_blocks] (same shape as in GQABlockSparseAttention)
521-         score_block = self.gemm(q, cache["centroids"], ctx=ctx)
522-         self.softmax(score_block, ctx=ctx)
523-         # Aggregate over heads → [*, num_blocks]
524-         aggr_block = self.max_over_heads(score_block, ctx=ctx)
525-
526-         # ---- 2. Quest-style max/min gating ----
527-         # Element-wise products with cached max/min stats
528-         s_max = self.mul_max(q, cache["max"], ctx=ctx)
529-         s_min = self.mul_min(q, cache["min"], ctx=ctx)
530-
531-         # Take the element-wise max between the two projections
532-         s = self.max_elementwise(s_max, s_min, ctx=ctx)
533-
534-         # Sum over feature dimension → [num_queries, num_heads, num_blocks]
535-         score_quest = self.sum_over_dim(s, ctx=ctx)
536-
537-         # Aggregate over queries → [num_heads, num_blocks] or [*, num_blocks]
538-         aggr_quest = self.max_over_queries(score_quest, ctx=ctx)
539-
540-         # ---- 3. Dynamic merge ----
541-         # For each block, take whichever score (block vs quest) is stronger.
542-         # This yields a per-block dynamic gate.
543-         combined_score = self.merge_scores(aggr_block, aggr_quest, ctx=ctx)
544-
545-         # ---- 4. Top-K block selection ----
546-         self.output_func(combined_score, o, ctx=ctx)
547-
548-     def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
549-         """
550-         cache["k"]: full key buffer for the page
551-         loc: index of the page / block being updated
552-         """
553-
554-         # Update mean (centroids)
555-         self.reduction_mean(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
556-
557-         # Update per-dimension maxima and minima
558-         self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
559-         self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)
560-
561-     def create_cache(self, page_size: int, head_dim: int):
562-         """
563-         For each block/page we maintain:
564-         - centroids: mean key per dimension
565-         - max: max key per dimension
566-         - min: min key per dimension
567-         """
568-         return {
569-             "centroids": (1, head_dim),
570-             "max": (1, head_dim),
571-             "min": (1, head_dim),
572-         }
573-