 import torch
 import torch.nn.functional as F
 import triton
+import triton.language as tl
 from ninetoothed import Symbol, Tensor
 
-BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
 
-q = Tensor(2, constexpr_shape=True)
-k = Tensor(2, constexpr_shape=True)
-v = Tensor(2, constexpr_shape=True)
-o = Tensor(2, constexpr_shape=True)
+def arrangement(q, k, v, o):
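+    # Tile q and o into (BLOCK_SIZE_M, emb_dim) row blocks per (batch, head);
+    # k and v are arranged so that every q block sees the full sequence of
+    # (BLOCK_SIZE_N, emb_dim) key/value blocks.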
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
 
-q_tiled = q.tile((BLOCK_SIZE_M, -1))
-k_tiled = k.tile((BLOCK_SIZE_N, -1)).tile((-1, -1)).expand((q_tiled.shape[0], -1))
-v_tiled = v.tile((BLOCK_SIZE_N, -1)).tile((-1, -1)).expand((q_tiled.shape[0], -1))
-o_tiled = o.tile((BLOCK_SIZE_M, -1))
+    def arrange_q_or_o(input):
+        arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))
 
+        return arranged
 
-@ninetoothed.jit
-def attention_kernel(q: q_tiled, k: k_tiled, v: v_tiled, o: o_tiled):
+    def arrange_k_or_v(input):
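+        # Group the BLOCK_SIZE_N blocks of one (batch, head) and broadcast the
+        # group across the q row blocks; the squeezes drop the singleton dims
+        # so the kernel sees plain 2-D blocks.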
+        arranged = (
+            input.tile((1, 1, BLOCK_SIZE_N, -1))
+            .tile((1, 1, -1, -1))
+            .expand((-1, -1, q_arranged.shape[-2], -1))
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    q_arranged = arrange_q_or_o(q)
+
+    return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
+
+
+def application(q, k, v, o):
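+    # Flash-Attention-style online softmax: keep a running row max (m_i) and
+    # row sum (l_i), rescaling the accumulator as each new k/v block arrives.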
     acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
     l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
     m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)
 
     for i in range(k.shape[0]):
-        qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i, 0]))
+        qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i]))
 
         m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
         p = ntl.exp2(qk - m_ij[:, None])
         l_ij = ntl.sum(p, 1)
 
         alpha = ntl.exp2(m_i - m_ij)
-        acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i, 0])
+        acc = acc * alpha[:, None] + ntl.dot(p.to(ntl.float16), v[i])
         m_i = m_ij
         l_i = l_i * alpha + l_ij
 
     acc /= l_i[:, None]
-    o = acc.to(ntl.float32)  # noqa: F841
+    o = acc  # noqa: F841
+
+
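+# 4-D (batch, heads, seq_len, emb_dim) tensors with constexpr shapes; the kernel
+# is built from the arrangement and the application.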
+q, k, v, o = (Tensor(4, constexpr_shape=True) for _ in range(4))
+attention_kernel = ninetoothed.make(arrangement, application, (q, k, v, o))
 
 
 def attention(q, k, v):
@@ -49,59 +66,212 @@ def attention(q, k, v):
     return o
 
 
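+# Hand-written Triton version of the same kernel, used below as a reference for
+# correctness and performance comparisons.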
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_stages=4, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=8
+        ),
+        triton.Config(
+            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32}, num_stages=4, num_warps=8
+        ),
+    ],
+    key=["EMB_DIM"],
+)
+@triton.jit
+def triton_attention_kernel(
+    q_ptr,
+    k_ptr,
+    v_ptr,
+    o_ptr,
+    q_stride_z,
+    q_stride_h,
+    q_stride_m,
+    q_stride_k,
+    k_stride_z,
+    k_stride_h,
+    k_stride_n,
+    k_stride_k,
+    v_stride_z,
+    v_stride_h,
+    v_stride_k,
+    v_stride_n,
+    o_stride_z,
+    o_stride_h,
+    o_stride_m,
+    o_stride_n,
+    SEQ_LEN: tl.constexpr,
+    EMB_DIM: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+):
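+    # Each program computes one (BLOCK_SIZE_M, EMB_DIM) output tile for a single
+    # (batch, head) pair, streaming over the keys/values in BLOCK_SIZE_N chunks.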
+    off_m = tl.program_id(0)
+    off_h = tl.program_id(1)
+    off_z = tl.program_id(2)
+
+    offs_m_start = off_m * BLOCK_SIZE_M
+
+    q_off = off_z * q_stride_z + off_h * q_stride_h
+    q_block_ptr = tl.make_block_ptr(
+        base=q_ptr + q_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(q_stride_m, q_stride_k),
+        offsets=(offs_m_start, 0),
+        block_shape=(BLOCK_SIZE_M, EMB_DIM),
+        order=(1, 0),
+    )
+    k_off = off_z * k_stride_z + off_h * k_stride_h
+    k_block_ptr = tl.make_block_ptr(
+        base=k_ptr + k_off,
+        shape=(EMB_DIM, SEQ_LEN),
+        strides=(k_stride_k, k_stride_n),
+        offsets=(0, 0),
+        block_shape=(EMB_DIM, BLOCK_SIZE_N),
+        order=(0, 1),
+    )
+    v_off = off_z * v_stride_z + off_h * v_stride_h
+    v_block_ptr = tl.make_block_ptr(
+        base=v_ptr + v_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(v_stride_k, v_stride_n),
+        offsets=(0, 0),
+        block_shape=(BLOCK_SIZE_N, EMB_DIM),
+        order=(1, 0),
+    )
+    o_off = off_z * o_stride_z + off_h * o_stride_h
+    o_block_ptr = tl.make_block_ptr(
+        base=o_ptr + o_off,
+        shape=(SEQ_LEN, EMB_DIM),
+        strides=(o_stride_m, o_stride_n),
+        offsets=(offs_m_start, 0),
+        block_shape=(BLOCK_SIZE_M, EMB_DIM),
+        order=(1, 0),
+    )
+
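+    # Pre-scale q by log2(e) so the softmax can use exp2 instead of exp.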
+    q = (tl.load(q_block_ptr) * 1.44269504089).to(q_block_ptr.type.element_ty)
+
+    acc = tl.zeros((BLOCK_SIZE_M, EMB_DIM), dtype=tl.float32)
+    l_i = tl.full((BLOCK_SIZE_M,), 1, dtype=tl.float32)
+    m_i = tl.full((BLOCK_SIZE_M,), float("-inf"), dtype=tl.float32)
+
+    for _ in range(0, tl.cdiv(SEQ_LEN, BLOCK_SIZE_N)):
+        k = tl.load(k_block_ptr)
+
+        qk = tl.dot(q, k)
+
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk -= m_ij[:, None]
+        p = tl.exp2(qk)
+        l_ij = tl.sum(p, 1)
+
+        v = tl.load(v_block_ptr)
+        alpha = tl.exp2(m_i - m_ij)
+        acc = acc * alpha[:, None] + tl.dot(p.to(v_block_ptr.type.element_ty), v)
+        m_i = m_ij
+        l_i = l_i * alpha + l_ij
+
+        v_block_ptr = tl.advance(v_block_ptr, (BLOCK_SIZE_N, 0))
+        k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_SIZE_N))
+
+    acc /= l_i[:, None]
+
+    tl.store(o_block_ptr, acc.to(o_ptr.type.element_ty))
+
+
+def triton_attention(q, k, v):
+    o = torch.empty_like(q)
+
+    batch_size, num_heads, seq_len, emb_dim = q.shape
+
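+    # One program per (q row block, head, batch) triple.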
+    def grid(meta):
+        return (
+            triton.cdiv(seq_len, meta["BLOCK_SIZE_M"]),
+            num_heads,
+            batch_size,
+        )
+
+    triton_attention_kernel[grid](
+        q,
+        k,
+        v,
+        o,
+        *q.stride(),
+        *k.stride(),
+        *v.stride(),
+        *o.stride(),
+        SEQ_LEN=seq_len,
+        EMB_DIM=emb_dim,
+    )
+
+    return o
+
+
 if __name__ == "__main__":
     torch.manual_seed(0)
-    shape = (1, 1, 1024, 64)
+    shape = (2, 4, 1024, 64)
     dtype = torch.float16
     q = torch.randn(shape, dtype=dtype, device="cuda")
     k = torch.randn(shape, dtype=dtype, device="cuda")
     v = torch.randn(shape, dtype=dtype, device="cuda")
 
-    ninetoothed_output = attention(
-        q.view(q.shape[-2:]), k.view(k.shape[-2:]), v.view(v.shape[-2:])
-    )
+    ninetoothed_output = attention(q, k, v)
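+    # scale=1 matches the kernels, which apply no 1/sqrt(emb_dim) scaling.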
     torch_output = F.scaled_dot_product_attention(q, k, v, scale=1)
+    triton_output = triton_attention(q, k, v)
     print(ninetoothed_output)
     print(torch_output)
-    if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01):
+    print(triton_output)
+    if torch.allclose(ninetoothed_output, torch_output, atol=0.01):
         print("✅ NineToothed and PyTorch match.")
     else:
         print("❌ NineToothed and PyTorch differ.")
+    if torch.allclose(ninetoothed_output, triton_output, atol=0.01):
+        print("✅ NineToothed and Triton match.")
+    else:
+        print("❌ NineToothed and Triton differ.")
 
     @triton.testing.perf_report(
         triton.testing.Benchmark(
-            x_names=["n"],
+            x_names=["seq_len"],
             x_vals=[2**i for i in range(10, 15)],
             line_arg="provider",
-            line_vals=["ninetoothed", "torch"],
-            line_names=["NineToothed", "PyTorch"],
-            styles=[("blue", "-"), ("green", "-")],
+            line_vals=["ninetoothed", "torch", "triton"],
+            line_names=["NineToothed", "PyTorch", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
             ylabel="TFLOPS",
             plot_name="attention-performance",
             args={},
         )
     )
-    def benchmark(n, provider):
-        d = 64
-        shape = (n, d)
+    def benchmark(seq_len, provider):
+        batch_size, num_heads, emb_dim = 4, 32, 64
+        shape = (batch_size, num_heads, seq_len, emb_dim)
         dtype = torch.float16
         q = torch.randn(shape, dtype=dtype, device="cuda")
         k = torch.randn(shape, dtype=dtype, device="cuda")
        v = torch.randn(shape, dtype=dtype, device="cuda")
 
         if provider == "ninetoothed":
-            ms = triton.testing.do_bench(
-                lambda: attention(
-                    q.view(q.shape[-2:]), k.view(k.shape[-2:]), v.view(v.shape[-2:])
-                )
-            )
+            ms = triton.testing.do_bench(lambda: attention(q, k, v))
         elif provider == "torch":
             ms = triton.testing.do_bench(
                 lambda: F.scaled_dot_product_attention(q, k, v, scale=1)
             )
+        elif provider == "triton":
+            ms = triton.testing.do_bench(lambda: triton_attention(q, k, v))
 
         def perf(ms):
-            flops_per_matmul = 2 * n * n * d
+            flops_per_matmul = 2 * batch_size * num_heads * seq_len * seq_len * emb_dim
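+            # Attention performs two matmuls per block: q @ k^T and p @ v.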
             total_flops = 2 * flops_per_matmul
 
             return total_flops * 1e-12 / (ms * 1e-3)