 from typing import Dict
 
 from .flow import vFlow
-from ..indexer import topK, GeMV, Softmax, Max, GeMM
-from ..cache import Mean
+from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply
+from ..cache import Mean as CMean, Max as CMax, Min as CMin
 from ..abs import ContextBase
 from .registry import register
 
@@ -12,10 +12,12 @@ class BlockSparseAttention(vFlow):
 
     def __init__(self):
         super().__init__()
-
+        # indexer ops
         self.gemv = GeMV()
         self.output_func = topK()
-        self.reduction = Mean(dim=1)
+
+        # cache ops
+        self.reduction = CMean(dim=1)
 
     def forward_indexer(self, q, o, cache, ctx):
 
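For intuition, here is a minimal plain-PyTorch sketch of what this indexer computes. The function name, standalone-tensor framing, and shapes are assumptions for illustration, not the library's actual API:

import torch

def block_sparse_index(q, centroids, k):
    # hypothetical standalone version of forward_indexer:
    # q: (head_dim,) query; centroids: (num_pages, head_dim), the
    # per-page mean of cached keys maintained by CMean(dim=1)
    scores = centroids @ q                  # GeMV: one score per page
    return torch.topk(scores, k).indices    # indices of the k best pages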
@@ -40,12 +42,14 @@ class GQABlockSparseAttention(vFlow):
 
     def __init__(self):
         super().__init__()
-
+        # indexer ops
         self.gemm = GeMM()
         self.softmax = Softmax(dim=0, scale=0.09)
         self.max_op = Max(dim=2)
         self.output_func = topK()
-        self.reduction = Mean(dim=1)
+
+        # cache ops
+        self.reduction = CMean(dim=1)
 
     def forward_indexer(self, q, o, cache, ctx):
 
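The GQA variant scores pages with all query heads of a group at once. A rough sketch under the same caveats (the real tensors carry extra batch and page axes, which is why the flow above uses different `dim` arguments):

import torch

def gqa_block_sparse_index(q_group, centroids, k, scale=0.09):
    # q_group: (num_q_heads, head_dim), the heads sharing one KV head
    scores = centroids @ q_group.T                # GeMM: (num_pages, num_q_heads)
    probs = torch.softmax(scores * scale, dim=0)  # normalize over pages
    page_score = probs.max(dim=-1).values         # strongest head vote per page
    return torch.topk(page_score, k).indices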
@@ -64,4 +68,46 @@ def create_cache(self, page_size: int, head_dim: int):
         return {
             "centroids": (1, head_dim)
         }
-
+
+
+@register("gqa_quest_sparse_attention")
+class GQAQuestSparseAttention(vFlow):
+
+    def __init__(self):
+        super().__init__()
+
+        # indexer ops
+        self.mul_max = Multiply()
+        self.mul_min = Multiply()
+        self.maximum_op = Maximum()
+        self.sum = Sum(dim=2)
+        self.max_op = Max(dim=1)
+        self.output_func = topK()
+
+        # cache ops
+        self.reduction_max = CMax(dim=1)
+        self.reduction_min = CMin(dim=1)
+
+    def forward_indexer(self, q, o, cache, ctx):
+        # Quest-style scoring: bound q.k per page from the cached
+        # channel-wise key extrema, then keep the top-k pages
+        s_max = self.mul_max(q, cache["max"], ctx=ctx)
+        s_min = self.mul_min(q, cache["min"], ctx=ctx)
+        s = self.maximum_op(s_max, s_min, ctx=ctx)
+        score = self.sum(s, ctx=ctx)
+        aggr_score = self.max_op(score, ctx=ctx)
+        self.output_func(aggr_score, o, ctx=ctx)
+
+    def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
+        # maintain per-page channel-wise max/min over the cached keys
+        self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
+        self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)
+
+    def create_cache(self, page_size: int, head_dim: int):
+        return {
+            "max": (1, head_dim),
+            "min": (1, head_dim)
+        }
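GQAQuestSparseAttention follows the Quest recipe: for each page, cache the channel-wise max and min of its keys and use them to upper-bound q.k, since for every channel the larger of q*k_max and q*k_min bounds that channel's contribution. A minimal sketch of the score, again with assumed shapes and a hypothetical function name:

import torch

def quest_page_score(q, k_max, k_min):
    # q: (head_dim,); k_max/k_min: (num_pages, head_dim), the CMax/CMin
    # reductions over each page's keys
    s = torch.maximum(q * k_max, q * k_min)  # (num_pages, head_dim)
    return s.sum(dim=-1)                     # upper bound on q.k per page

In the flow above, Max(dim=1) then collapses this bound, presumably over the query heads of a GQA group, before topK picks the pages.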