@@ -2,7 +2,7 @@
 from typing import Dict
 
 from .flow import vFlow
-from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply
+from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply, L2Norm
 from ..cache import Mean as CMean, Max as CMax, Min as CMin
 from ..abs import ContextBase
 from .registry import register
@@ -571,3 +571,71 @@ def create_cache(self, page_size: int, head_dim: int):
             "min": (1, head_dim),
         }
 
+
+
+@register("gqa_dynamic_entropy_sparse_attention")
+class GQADynamicEntropySparseAttention(vFlow):
+    """
+    Dynamic-entropy gated sparse attention.
+
+    Two scoring paths per page:
+      1) Centroid path with softmax over pages; its query-group energy
+         (L2 norm across queries) forms a confidence gate.
+      2) QUEST-style envelope path (per-channel max/min upper bound).
+
+    We scale the centroid score by the confidence and then take an
+    elementwise max with the QUEST score before top-k.
+    """
+    def __init__(self):
+        super().__init__()
+        # Block path
+        self.gemm = GeMM()
+        self.softmax = Softmax(dim=0, scale=0.09)
+        self.max_over_heads = Max(dim=2)
+        self.l2_over_queries = L2Norm(dim=1)
+        self.mul = Multiply()
+        # QUEST path
+        self.mul_max = Multiply()
+        self.mul_min = Multiply()
+        self.maximum = Maximum()
+        self.sum_over_dim = Sum(dim=2)
+        self.max_over_queries = Max(dim=1)
+        # Merge + output
+        self.merge = Maximum()
+        self.output_func = topK()
+        # Cache reductions
+        self.reduction_mean = CMean(dim=1)
+        self.reduction_max = CMax(dim=1)
+        self.reduction_min = CMin(dim=1)
+
+    def forward_indexer(self, q, o, cache: Dict[str, torch.Tensor], ctx: ContextBase):
+        # Block scoring and confidence
+        score_block = self.gemm(q, cache["centroids"], ctx=ctx)
+        self.softmax(score_block, ctx=ctx)
+        aggr_block = self.max_over_heads(score_block, ctx=ctx)
+        conf = self.l2_over_queries(score_block, ctx=ctx)
+        gated_block = self.mul(aggr_block, conf, ctx=ctx)
+
+        # QUEST scoring
+        s_max = self.mul_max(q, cache["max"], ctx=ctx)
+        s_min = self.mul_min(q, cache["min"], ctx=ctx)
+        s = self.maximum(s_max, s_min, ctx=ctx)
+        score_quest = self.sum_over_dim(s, ctx=ctx)
+        aggr_quest = self.max_over_queries(score_quest, ctx=ctx)
+
+        # Merge and select
+        combined = self.merge(gated_block, aggr_quest, ctx=ctx)
+        self.output_func(combined, o, ctx=ctx)
+
+    def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
+        self.reduction_mean(cache["k"], cache["centroids"], loc=loc, ctx=ctx)
+        self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
+        self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)
+
+    def create_cache(self, page_size: int, head_dim: int):
+        return {
+            "centroids": (1, head_dim),
+            "max": (1, head_dim),
+            "min": (1, head_dim),
+        }
+
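For readers who don't have the vFlow indexer DSL in their head, below is a rough plain-PyTorch sketch of the math this flow expresses. The tensor shapes, the example sizes, and the exact reduction order are assumptions for illustration; the DSL's real dimension conventions (and the in-place behavior of `Softmax` and `topK`) may differ.

```python
import torch

# Hypothetical sizes, for illustration only.
H, Q, P, D = 4, 8, 32, 64          # heads, queries, pages, head_dim
scale, k = 0.09, 8                 # softmax scale and number of pages to keep

q = torch.randn(H, Q, D)
keys = torch.randn(P, 16, D)       # 16 cached keys per page

# Cache reductions (cf. forward_cache): one centroid/max/min row per page.
centroids = keys.mean(dim=1)       # [P, D]
k_max = keys.amax(dim=1)           # [P, D]
k_min = keys.amin(dim=1)           # [P, D]

# Centroid path: softmax over pages, gated by the L2 "energy" of the
# per-page scores across queries (a peaked softmax -> high confidence).
scores = torch.einsum("hqd,pd->hqp", q, centroids) * scale
probs = scores.softmax(dim=-1)                             # mass over pages
block = probs.amax(dim=(0, 1))                             # [P]
conf = torch.linalg.vector_norm(probs.amax(dim=0), dim=0)  # [P]
gated = block * conf

# QUEST-style envelope: per channel, k_min <= k <= k_max implies
# q*k <= max(q*k_max, q*k_min); summing channels bounds q.k for the page.
ub = torch.maximum(q.unsqueeze(2) * k_max, q.unsqueeze(2) * k_min)  # [H, Q, P, D]
quest = ub.sum(dim=-1).amax(dim=(0, 1))                    # [P]

# Merge the gated estimate with the conservative bound, then select pages.
combined = torch.maximum(gated, quest)
pages = combined.topk(k).indices
```

The apparent intent of the gate: when the softmax over pages is diffuse (high entropy), the L2 energy is small, the centroid estimate shrinks, and the elementwise max falls back to the conservative QUEST upper bound; when the softmax is peaked, the centroid path dominates.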