@@ -44,6 +44,7 @@ def _attn_fwd_subtile(
44
44
dtype : tl .constexpr ,
45
45
STAGE : tl .constexpr ,
46
46
SUBTILING : tl .constexpr ,
47
+ VECT_MUL : tl .constexpr ,
47
48
):
48
49
qk = tl .dot (q , k )
49
50
if STAGE == 2 :
@@ -53,7 +54,10 @@ def _attn_fwd_subtile(
53
54
qk -= m_ij [:, None ]
54
55
else :
55
56
m_ij = tl .maximum (m_i , tl .max (qk , 1 ) * qk_scale )
56
- qk = qk * qk_scale - m_ij [:, None ]
57
+ if VECT_MUL :
58
+ qk = _fma_f32x2 (qk , qk_scale , - m_ij [:, None ])
59
+ else :
60
+ qk = qk * qk_scale - m_ij [:, None ]
57
61
p = tl .math .exp2 (qk )
58
62
# -- compute correction factor
59
63
alpha = tl .math .exp2 (m_i - m_ij )
@@ -65,8 +69,12 @@ def _attn_fwd_subtile(
65
69
66
70
if SUBTILING :
67
71
acc0 , acc1 = acc .reshape ([BM , 2 , BN // 2 ]).permute (0 , 2 , 1 ).split ()
68
- acc0 = acc0 * alpha [:, None ]
69
- acc1 = acc1 * alpha [:, None ]
72
+ if VECT_MUL :
73
+ acc0 = _mul_f32x2 (acc0 , alpha [:, None ])
74
+ acc1 = _mul_f32x2 (acc1 , alpha [:, None ])
75
+ else :
76
+ acc0 = acc0 * alpha [:, None ]
77
+ acc1 = acc1 * alpha [:, None ]
70
78
acc = tl .join (acc0 , acc1 ).permute (0 , 2 , 1 ).reshape ([BM , BN ])
71
79
else :
72
80
acc = acc * alpha [:, None ]
@@ -109,6 +117,7 @@ def _attn_fwd_inner_oss_dp(
109
117
N_CTX : tl .constexpr ,
110
118
warp_specialize : tl .constexpr ,
111
119
SUBTILING : tl .constexpr ,
120
+ VECT_MUL : tl .constexpr ,
112
121
):
113
122
# range of values handled by this stage
114
123
if STAGE == 1 :
@@ -144,6 +153,7 @@ def _attn_fwd_inner_oss_dp(
144
153
dtype ,
145
154
STAGE ,
146
155
SUBTILING ,
156
+ VECT_MUL ,
147
157
)
148
158
l_i1 , m_i1 , acc1 = _attn_fwd_subtile (
149
159
q1 ,
@@ -159,6 +169,7 @@ def _attn_fwd_inner_oss_dp(
159
169
dtype ,
160
170
STAGE ,
161
171
SUBTILING ,
172
+ VECT_MUL ,
162
173
)
163
174
164
175
offsetkv_y += BLOCK_N
@@ -191,18 +202,25 @@ def _host_descriptor_pre_hook(nargs):
191
202
if is_tile_enabled ():
192
203
configs = [
193
204
triton .Config (
194
- {"BLOCK_M" : BM , "BLOCK_N" : BN , "occupancy" : occ , "SUBTILING" : subtile },
205
+ {
206
+ "BLOCK_M" : BM ,
207
+ "BLOCK_N" : BN ,
208
+ "occupancy" : occ ,
209
+ "SUBTILING" : subtile ,
210
+ "VECT_MUL" : vectmul ,
211
+ },
195
212
pre_hook = _host_descriptor_pre_hook ,
196
213
)
197
214
for BM in [64 , 128 , 256 ]
198
215
for BN in [64 , 128 ]
199
216
for occ in [1 , 2 ]
200
217
for subtile in [True ]
218
+ for vectmul in [True ]
201
219
]
202
220
else :
203
221
configs = [
204
222
triton .Config (
205
- {"BLOCK_M" : BM , "BLOCK_N" : BN , "SUBTILING" : subtile },
223
+ {"BLOCK_M" : BM , "BLOCK_N" : BN , "SUBTILING" : subtile , "VECT_MUL" : vectmul },
206
224
num_stages = s ,
207
225
num_warps = w ,
208
226
pre_hook = _host_descriptor_pre_hook ,
@@ -212,7 +230,8 @@ def _host_descriptor_pre_hook(nargs):
212
230
for BN in [128 ]
213
231
for s in NUM_STAGES_OPTIONS
214
232
for w in [4 ]
215
- for subtile in [False ] # disable subtiling for now
233
+ for subtile in [True ]
234
+ for vectmul in [False ]
216
235
]
217
236
218
237
@@ -242,6 +261,47 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
242
261
return tl .make_tensor_descriptor (desc_or_ptr , shape , strides , block_shape )
243
262
244
263
264
@triton.jit
def _mul_f32x2(a, b):
    """Elementwise ``a * b`` computed with the packed PTX ``mul.f32x2`` op.

    Each inline-asm invocation handles two elements at a time (``pack=2``):
    the two 32-bit lanes of each operand are packed into one 64-bit
    register, multiplied with a single ``mul.f32x2`` instruction, and the
    64-bit result is unpacked into the two 32-bit outputs.

    Args:
        a: first operand.
        b: second operand.

    Returns:
        The elementwise product, with result dtype ``tl.float32``.

    NOTE(review): the operands are bound to 32-bit ``r`` registers and fed
    to an f32 instruction, so they are presumably expected to already be
    float32 -- confirm against callers. The ``f32x2`` instruction forms
    presumably require a recent NVIDIA target; confirm before enabling on
    other hardware.
    """
    # Operand layout: $0/$1 are the two outputs, $2/$3 the two lanes of
    # `a`, $4/$5 the two lanes of `b`.
    return tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b64 ra, rb, rc;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mul.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        constraints="=r,=r,r,r,r,r",
        args=[a, b],
        dtype=tl.float32,
        is_pure=True,
        pack=2,
    )
282
+
283
+
284
@triton.jit
def _fma_f32x2(a, b, c):
    """Elementwise fused multiply-add ``a * b + c`` via PTX ``fma.rn.f32x2``.

    Each inline-asm invocation handles two elements at a time (``pack=2``):
    the two 32-bit lanes of each operand are packed into one 64-bit
    register, combined with a single round-to-nearest (``.rn``) packed FMA,
    and the 64-bit result is unpacked into the two 32-bit outputs.

    Args:
        a: multiplicand.
        b: multiplier.
        c: addend.

    Returns:
        The elementwise ``a * b + c``, with result dtype ``tl.float32``.

    NOTE(review): the operands are bound to 32-bit ``r`` registers and fed
    to an f32 instruction, so they are presumably expected to already be
    float32 -- confirm against callers. The ``f32x2`` instruction forms
    presumably require a recent NVIDIA target; confirm before enabling on
    other hardware.
    """
    # Operand layout: $0/$1 are the two outputs, $2/$3 the two lanes of
    # `a`, $4/$5 the two lanes of `b`, $6/$7 the two lanes of `c`.
    return tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b64 ra, rb, rc, rd;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mov.b64 rc, { $6, $7 };
            fma.rn.f32x2 rd, ra, rb, rc;
            mov.b64 { $0, $1 }, rd;
        }
        """,
        constraints="=r,=r,r,r,r,r,r,r",
        args=[a, b, c],
        dtype=tl.float32,
        is_pure=True,
        pack=2,
    )
303
+
304
+
245
305
@triton .jit
246
306
def _attn_fwd_tma_dp (
247
307
sm_scale ,
@@ -263,6 +323,7 @@ def _attn_fwd_tma_dp(
263
323
warp_specialize : tl .constexpr , #
264
324
dtype : tl .constexpr ,
265
325
SUBTILING : tl .constexpr ,
326
+ VECT_MUL : tl .constexpr ,
266
327
):
267
328
tl .static_assert (BLOCK_N <= HEAD_DIM )
268
329
start_m = pid # tl.program_id(0)
@@ -317,6 +378,7 @@ def _attn_fwd_tma_dp(
317
378
N_CTX , #
318
379
warp_specialize ,
319
380
SUBTILING ,
381
+ VECT_MUL ,
320
382
)
321
383
if STAGE & 2 :
322
384
acc0 , acc1 , l_i0 , l_i1 , m_i0 , m_i1 = _attn_fwd_inner_oss_dp (
@@ -344,6 +406,7 @@ def _attn_fwd_tma_dp(
344
406
N_CTX , #
345
407
warp_specialize ,
346
408
SUBTILING ,
409
+ VECT_MUL ,
347
410
)
348
411
349
412
m_i0 += tl .math .log2 (l_i0 )
@@ -383,6 +446,7 @@ def _attn_fwd(
383
446
warp_specialize : tl .constexpr , #
384
447
dtype : tl .constexpr ,
385
448
SUBTILING : tl .constexpr ,
449
+ VECT_MUL : tl .constexpr ,
386
450
):
387
451
pid = tl .program_id (0 )
388
452
off_hz = tl .program_id (1 )
@@ -406,6 +470,7 @@ def _attn_fwd(
406
470
warp_specialize ,
407
471
dtype ,
408
472
SUBTILING ,
473
+ VECT_MUL ,
409
474
)
410
475
411
476
@@ -434,6 +499,7 @@ def _attn_fwd_persist(
434
499
OUTER_LOOP : tl .constexpr ,
435
500
dtype : tl .constexpr ,
436
501
SUBTILING : tl .constexpr ,
502
+ VECT_MUL : tl .constexpr ,
437
503
):
438
504
n_tile_num = tl .cdiv (N_CTX , BLOCK_M )
439
505
prog_id = tl .program_id (0 )
@@ -469,6 +535,7 @@ def _attn_fwd_persist(
469
535
warp_specialize and not OUTER_LOOP ,
470
536
dtype ,
471
537
SUBTILING ,
538
+ VECT_MUL ,
472
539
)
473
540
tile_idx += num_progs
474
541
0 commit comments