
Commit 17b29a1

1 parent 8a7656c commit 17b29a1

7 files changed: 84 additions, 13 deletions


vortex_torch/cache/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -1,11 +1,17 @@
 from .context import Context
-from .reduce import Mean
+from .reduce import Mean, Max, Min
+from .matmul import GeMM
+from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul
+from .elementwise_binary import Maximum, Minimum, Multiply, Add
 from .triton_kernels import set_kv_buffer_launcher
 
 
 __all__ = [
     "set_kv_buffer_launcher",
-    "Mean",
+    "Mean", "Max", "Min",
+    "GeMM",
+    "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul",
+    "Maximum", "Minimum", "Multiply", "Add",
     "Context"
 ]
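With this change the cache namespace exposes the same reduce and elementwise op families as the indexer. A quick usage sketch, assuming the package layout above (constructor shape follows the existing Mean(dim=1) usage in the flows):

    from vortex_torch.cache import Max, Min, Mean

    # reduce over dim 1, same convention as the existing Mean(dim=1)
    reduction_max = Max(dim=1)
    reduction_min = Min(dim=1)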

vortex_torch/cache/reduce.py

Lines changed: 5 additions & 0 deletions
@@ -171,3 +171,8 @@ def __init__(self, dim = 1):
         super().__init__(dim)
         self.reduce_type = ReduceType.L2Norm
 
+class Sum(Reduce):
+
+    def __init__(self, dim = 1):
+        super().__init__(dim)
+        self.reduce_type = ReduceType.Sum
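The subclass only sets reduce_type; the shared Reduce base handles the rest. For intuition, ReduceType.Sum simply collapses the configured axis by summation. A plain-PyTorch reference for what Sum(dim=2) computes (shapes here are illustrative only, not the library's calling convention):

    import torch

    x = torch.randn(4, 8, 16)   # e.g. (heads, pages, keys_per_page)
    page_sum = x.sum(dim=2)     # (4, 8): one summed value per page, per head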

vortex_torch/flow/algorithms.py

Lines changed: 53 additions & 7 deletions
@@ -2,8 +2,8 @@
 from typing import Dict
 
 from .flow import vFlow
-from ..indexer import topK, GeMV, Softmax, Max, GeMM
-from ..cache import Mean
+from ..indexer import topK, GeMV, Softmax, Max, Sum, GeMM, Maximum, Multiply
+from ..cache import Mean as CMean, Max as CMax, Min as CMin
 from ..abs import ContextBase
 from .registry import register
 
@@ -12,10 +12,12 @@ class BlockSparseAttention(vFlow):
 
     def __init__(self):
         super().__init__()
-
+        #indexer ops
         self.gemv = GeMV()
         self.output_func = topK()
-        self.reduction = Mean(dim=1)
+
+        #cache ops
+        self.reduction = CMean(dim=1)
 
     def forward_indexer(self, q, o, cache, ctx):
 
@@ -40,12 +42,14 @@ class GQABlockSparseAttention(vFlow):
 
     def __init__(self):
         super().__init__()
-
+        #indexer ops
         self.gemm = GeMM()
         self.softmax = Softmax(dim=0, scale=0.09)
         self.max_op = Max(dim=2)
         self.output_func = topK()
-        self.reduction = Mean(dim=1)
+
+        #cache ops
+        self.reduction = CMean(dim=1)
 
     def forward_indexer(self, q, o, cache, ctx):
 
@@ -64,4 +68,46 @@ def create_cache(self, page_size: int, head_dim: int):
         return {
             "centroids": (1, head_dim)
         }
-
+
+
+
+@register("gqa_quest_sparse_attention")
+class GQAQuestSparseAttention(vFlow):
+
+    def __init__(self):
+        super().__init__()
+
+        #indexer ops
+        self.mul_max = Multiply()
+        self.mul_min = Multiply()
+        self.maximum_op = Maximum()
+        self.sum = Sum(dim=2)
+        self.max_op = Max(dim=1)
+        self.output_func = topK()
+
+        #cache ops
+        self.reduction_max = CMax(dim=1)
+        self.reduction_min = CMin(dim=1)
+
+    def forward_indexer(self, q, o, cache, ctx):
+
+        s_max = self.mul_max(q, cache["max"], ctx=ctx)
+        s_min = self.mul_min(q, cache["min"], ctx=ctx)
+        s = self.maximum_op(s_max, s_min, ctx=ctx)
+        score = self.sum(s, ctx=ctx)
+        aggr_score = self.max_op(score, ctx=ctx)
+        self.output_func(aggr_score, o, ctx=ctx)
+
+    def forward_cache(self, cache: Dict[str, torch.Tensor], loc: torch.Tensor, ctx: ContextBase):
+
+        self.reduction_max(cache["k"], cache["max"], loc=loc, ctx=ctx)
+        self.reduction_min(cache["k"], cache["min"], loc=loc, ctx=ctx)
+
+
+
+    def create_cache(self, page_size: int, head_dim: int):
+
+        return {
+            "max": (1, head_dim),
+            "min": (1, head_dim)
+        }
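For context, this flow implements Quest-style page selection: the cache keeps the channel-wise max and min of the keys in each page, and the indexer scores a page by summing max(q*k_max, q*k_min) over the head dimension. Since each q_d*k_d is monotone in k_d, its maximum over a page is attained at the channel-wise min or max, so this sum upper-bounds q.k for every key in the page. A self-contained PyTorch sketch of the same arithmetic (the shapes and the final top-k are illustrative, not the library's calling convention):

    import torch

    heads, pages, head_dim, k = 8, 64, 128, 4
    q = torch.randn(heads, 1, head_dim)
    k_max = torch.randn(pages, head_dim)          # per-page channel-wise max of keys
    k_min = k_max - torch.rand(pages, head_dim)   # per-page channel-wise min (<= max)

    s_max = q * k_max                   # (heads, pages, head_dim)
    s_min = q * k_min
    s = torch.maximum(s_max, s_min)     # channel-wise upper bound on q * k
    score = s.sum(dim=2)                # (heads, pages): per-page score bound
    aggr = score.max(dim=0).values      # (pages,): aggregate the bound across heads
    top_pages = aggr.topk(k).indices    # keep only the most promising pages

On the cache side, forward_cache only has to maintain those running per-page max/min statistics, which is what the two CMax/CMin reductions update at the written locations.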

vortex_torch/indexer/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from .matmul import GeMV, GeMM
 from .output_func import topK
-from .reduce import Max, Mean, Min, L2Norm
+from .reduce import Max, Mean, Min, L2Norm, Sum
 from .scan import Softmax, Normalize
 from .transpose import Transpose
 from .elementwise_binary import Maximum, Minimum, Multiply, Add
@@ -10,7 +10,7 @@
 __all__ = [
     "GeMV", "GeMM",
     "topK",
-    "Max", "Mean", "Min", "L2Norm",
+    "Max", "Mean", "Min", "L2Norm", "Sum",
     "Softmax", "Normalize",
     "Transpose",
     "Maximum", "Minimum", "Multiply", "Add",

vortex_torch/indexer/reduce.py

Lines changed: 7 additions & 0 deletions
@@ -114,3 +114,10 @@ class L2Norm(Reduce):
     def __init__(self, dim = 1):
         super().__init__(dim)
         self.reduce_type = ReduceType.L2Norm
+
+
+class Sum(Reduce):
+
+    def __init__(self, dim = 1):
+        super().__init__(dim)
+        self.reduce_type = ReduceType.Sum

vortex_torch/indexer/triton_kernels/reduce_impl.py

Lines changed: 8 additions & 2 deletions
@@ -59,7 +59,10 @@ def reduce_rr_kernel(
 
     elif REDUCE_TYPE == 3:
         x_i_reduce = tl.sqrt(tl.sum(x_i * x_i, axis=1))
-
+
+    elif REDUCE_TYPE == 4:
+        x_i_reduce = tl.sum(x_i, axis=1)
+
     else:
         x_i_reduce = tl.zeros((max_chunk_size, x_D1), dtype=tl.bfloat16)
 
@@ -80,7 +83,10 @@ def reduce_rr_kernel(
 
    elif REDUCE_TYPE == 3:
        x_i_reduce = tl.sqrt(tl.sum(x_i * x_i, axis=2))
-
+
+   elif REDUCE_TYPE == 4:
+       x_i_reduce = tl.sum(x_i, axis=2)
+
    else:
        x_i_reduce = tl.zeros((max_chunk_size, x_D1), dtype=tl.float32)
 
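The new REDUCE_TYPE == 4 branches lower ReduceType.Sum to tl.sum along the reduced axis, mirroring the existing mean/max/min/L2-norm branches. As a standalone illustration of the primitive (a toy kernel for exposition, not the repository's reduce_rr_kernel):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_sum_kernel(x_ptr, out_ptr, D, BLOCK: tl.constexpr):
        # one program per row: load up to BLOCK elements, masked to the row width D
        row = tl.program_id(0)
        offs = tl.arange(0, BLOCK)
        mask = offs < D
        x = tl.load(x_ptr + row * D + offs, mask=mask, other=0.0)
        tl.store(out_ptr + row, tl.sum(x, axis=0))

    x = torch.randn(32, 128, device="cuda")
    out = torch.empty(32, device="cuda")
    row_sum_kernel[(32,)](x, out, 128, BLOCK=128)
    assert torch.allclose(out, x.sum(dim=1), atol=1e-4)

REDUCE_TYPE is presumably a tl.constexpr (the uppercase naming suggests it), in which case Triton compiles a specialized kernel per reduce type and the added branch costs nothing at runtime for the other reductions.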

vortex_torch/utils.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ class ReduceType(Enum):
     Max = 1
     Min = 2
     L2Norm = 3
+    Sum = 4
 
 
 class ElementwiseBinaryOpType(Enum):
