Skip to content

Commit 59a89ea

Browse files
flow: add GQADynamicEntropySparseAttention (confidence-gated hybrid of centroid and QUEST paths)

- Introduce dynamic gating using L2 norm of per-page softmaxed block scores over query groups (entropy proxy)
- Merge gated block score with QUEST envelope score via elementwise max before topK
- Reuse existing ops; no new kernels required

Co-authored-by: openhands <[email protected]>
1 parent b0f3cd6 commit 59a89ea

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

vortex_torch/flow/algorithms.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Dict
33

44
from .flow import vFlow
5-
from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply
5+
from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply, L2Norm
66
from ..cache import Mean as CMean, Max as CMax, Min as CMin
77
from ..abs import ContextBase
88
from .registry import register
@@ -571,3 +571,71 @@ def create_cache(self, page_size: int, head_dim: int):
571571
"min": (1, head_dim),
572572
}
573573

574+
575+
576+
@register("gqa_dynamic_entropy_sparse_attention")
class GQADynamicEntropySparseAttention(vFlow):
    """
    Dynamic-entropy gated sparse attention.

    Two scoring paths per page:
      1) Centroid path with softmax over pages; its query-group energy
         (L2 norm across queries) forms a confidence gate.
      2) QUEST-style envelope path (max/min upper bound).

    We scale the centroid score by the confidence and then take an
    elementwise max with the QUEST score before top-k.
    """

    def __init__(self) -> None:
        super().__init__()
        # --- Centroid ("block") scoring path ---
        self.gemm = GeMM()
        # NOTE(review): dim=0 softmaxes across pages; scale=0.09 looks like a
        # 1/sqrt(head_dim) attention temperature (~head_dim 128) — confirm.
        self.softmax = Softmax(dim=0, scale=0.09)
        self.max_over_heads = Max(dim=2)
        # Confidence gate: L2 norm of the softmaxed page scores across the
        # query-group dimension (an entropy proxy — peaked distributions
        # yield larger norms than flat ones).
        self.l2_over_queries = L2Norm(dim=1)
        self.mul = Multiply()
        # --- QUEST envelope path (upper bound via per-page max/min keys) ---
        self.mul_max = Multiply()
        self.mul_min = Multiply()
        self.maximum = Maximum()
        self.sum_over_dim = Sum(dim=2)
        self.max_over_queries = Max(dim=1)
        # --- Merge the two path scores and select top-k pages ---
        self.merge = Maximum()
        self.output_func = topK()
        # Per-page cache reductions over the token dimension.
        self.reduction_mean = CMean(dim=1)
        self.reduction_max = CMax(dim=1)
        self.reduction_min = CMin(dim=1)

    def forward_indexer(self, q, o, cache: Dict[str, torch.Tensor], ctx: ContextBase):
        """Score pages for query *q* and write the selected top-k into *o*.

        Builds both scoring paths, gates the centroid score by its
        confidence, merges with the QUEST score via elementwise max, and
        hands the combined score to top-k selection.
        """
        # Centroid scoring and confidence gate. The softmax is applied to
        # score_block without capturing a return value — presumably an
        # in-place op; verify against the Softmax op's contract.
        score_block = self.gemm(q, cache["centroids"], ctx=ctx)
        self.softmax(score_block, ctx=ctx)
        aggr_block = self.max_over_heads(score_block, ctx=ctx)
        conf = self.l2_over_queries(score_block, ctx=ctx)
        gated_block = self.mul(aggr_block, conf, ctx=ctx)

        # QUEST scoring: per-element upper bound via max(q*k_max, q*k_min),
        # summed over the feature dim, then max-reduced over queries.
        s_max = self.mul_max(q, cache["max"], ctx=ctx)
        s_min = self.mul_min(q, cache["min"], ctx=ctx)
        s = self.maximum(s_max, s_min, ctx=ctx)
        score_quest = self.sum_over_dim(s, ctx=ctx)
        aggr_quest = self.max_over_queries(score_quest, ctx=ctx)

        # Elementwise max of the two path scores, then top-k page selection.
        combined = self.merge(gated_block, aggr_quest, ctx=ctx)
        self.output_func(combined, o, ctx=ctx)

    def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
        """Update per-page summaries (centroid mean, max, min of keys) at *loc*."""
        self.reduction_mean(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
        self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
        self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)

    def create_cache(self, page_size: int, head_dim: int) -> Dict[str, tuple]:
        """Declare the per-page cache buffer shapes (one row of head_dim each).

        ``page_size`` is unused here — each page keeps a single reduced row
        per summary, matching the sibling flows in this module.
        """
        return {
            "centroids": (1, head_dim),
            "max": (1, head_dim),
            "min": (1, head_dim),
        }
641+

0 commit comments

Comments
 (0)