@@ -116,6 +116,12 @@ def parse_op_args(args: List[str]):
         action="store_true",
         help="enable causal",
     )
+    parser.add_argument(
+        "--window-size",
+        type=lambda x: tuple(map(int, x.split(","))),
+        default=(-1, -1),
+        help="sliding window size as (left_window, right_window). Use (-1, -1) to disable sliding window",
+    )
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
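
For reference, a minimal standalone sketch (not part of the diff) of how the new flag parses, assuming the same lambda converter: passing --window-size 128,0 yields the tuple (128, 0), while omitting the flag keeps the default (-1, -1), which leaves the sliding window disabled.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--window-size",
    type=lambda x: tuple(map(int, x.split(","))),
    default=(-1, -1),
)

print(parser.parse_args(["--window-size", "128,0"]).window_size)  # (128, 0)
print(parser.parse_args([]).window_size)  # (-1, -1)
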
@@ -177,6 +183,13 @@ def __init__(
         self.H = args.n_heads
         self.D_HEAD = args.d_head
         self.causal = args.causal
+        self.window_size = args.window_size
+        self.local = self.window_size != (-1, -1)
+
+        # Prioritize sliding window over causal when both are specified
+        if self.causal and self.local:
+            self.causal = False
+
         self.native_sdpa = args.native_sdpa
         self.pt2_sdpa = args.pt2_sdpa
         self.input_types = args.input_types
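
A tiny illustration of the precedence rule above, using a hypothetical stand-in for the parsed args (not part of the diff): when both --causal and a window are given, the window wins and causal is switched off.

from types import SimpleNamespace

# Hypothetical stand-in for the parsed args, only to show the precedence rule.
args = SimpleNamespace(causal=True, window_size=(128, 0))
local = args.window_size != (-1, -1)
causal = args.causal and not local  # sliding window takes priority over causal
print(local, causal)  # True False
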
@@ -191,10 +204,22 @@ def aten(
     ) -> Callable:
         def _inner():
             N_CTX = q.shape[2]
-            M = torch.tril(torch.ones((N_CTX, N_CTX), device=self.device))
+            N_CTX_KV = k.shape[2]
             p = torch.matmul(q, k.transpose(2, 3)) * self.sm_scale
+
             if self.causal:
+                M = torch.tril(torch.ones((N_CTX, N_CTX_KV), device=self.device))
                 p[:, :, M == 0] = float("-inf")
+            elif self.local:
+                # Create sliding window mask
+                i = torch.arange(N_CTX, device=self.device).unsqueeze(1)
+                j = torch.arange(N_CTX_KV, device=self.device).unsqueeze(0)
+                # Allow attention if within window (both left and right)
+                left_window, right_window = self.window_size
+                window_mask = ((i - j) <= left_window) & ((j - i) <= right_window)
+                # Note: causal is already handled separately above and should not be true when sliding_window is true
+                p[:, :, ~window_mask] = float("-inf")
+
             p = torch.softmax(p.float(), dim=-1).to(q.dtype)
             # p = torch.exp(p)
             ref_out = torch.matmul(p, v)
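
To make the reference (aten) mask easier to verify, here is a small CPU-only sketch of the same indexing on toy sizes (not part of the diff). The parentheses around each comparison matter, since & binds tighter than <= in Python.

import torch

N_CTX, N_CTX_KV = 4, 4
left_window, right_window = 1, 0  # toy window: each query sees itself and one previous key
i = torch.arange(N_CTX).unsqueeze(1)
j = torch.arange(N_CTX_KV).unsqueeze(0)
window_mask = ((i - j) <= left_window) & ((j - i) <= right_window)
print(window_mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])
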
@@ -209,6 +234,10 @@ def sdpa(
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
+        if self.local:
+            # sdpa with flash attention backend doesn't support non-null attn_mask
+            raise NotImplementedError("Skip")
+
         def sdpa_flash_attention(q, k, v):
             cxt = (
                 nullcontext()
@@ -249,7 +278,10 @@ def flash_v2(
     ) -> Callable:
         qkv = make_packed_qkv(q, k, v)
         fn = lambda: flash_attn_func(
-            qkv, softmax_scale=self.sm_scale, causal=self.causal
+            qkv,
+            softmax_scale=self.sm_scale,
+            causal=self.causal,
+            window_size=self.window_size,
         )
         return fn

@@ -264,7 +296,17 @@ def xformers_preprocess(
         q_1 = q_1.contiguous()
         k_1 = k_1.contiguous()
         v_1 = v_1.contiguous()
-        attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
+
+        # Create attention bias based on settings
+        attn_bias = None
+        if self.causal:
+            attn_bias = xformers.ops.LowerTriangularMask()
+        elif self.local:
+            attn_bias = xformers.ops.fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=self.window_size[0],
+                window_right=self.window_size[1],
+            )
+
         fhma_input = xformers_fmha.Inputs(
             query=q_1, key=k_1, value=v_1, attn_bias=attn_bias, scale=self.sm_scale
         )
@@ -291,6 +333,9 @@ def xformers_splitk(
         k: torch.Tensor,
         v: torch.Tensor,
     ):
+        if self.local or self.causal:
+            # SplitK doesn't support local or causal attention here yet
+            raise NotImplementedError("Skip")
         need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
         fhma_input = self.xformers_preprocess(q, k, v)
         xformers_splitk_fhma = xformers_fmha.triton_splitk.FwOp
@@ -303,6 +348,10 @@ def xformers_splitk(
         label=f"cudnn-sdpa-{torch.backends.cudnn.version()}",
     )
     def cudnn_sdpa(self, q, k, v):
+        if self.local:
+            # Skip CUDNN SDPA for local attention for now
+            raise NotImplementedError("Skip")
+
         return lambda: _sdpa_cudnn_attention(
             q, k, v, is_causal=self.causal, scale=self.sm_scale
         )
@@ -318,7 +367,12 @@ def cutedsl_blackwell(
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
         return lambda: facute_flash_attn_func(
-            q, k, v, softmax_scale=self.sm_scale, causal=self.causal
+            q,
+            k,
+            v,
+            softmax_scale=self.sm_scale,
+            causal=self.causal,
+            window_size=self.window_size if self.local else (None, None),
         )

     @register_benchmark()
@@ -328,12 +382,27 @@ def flex_attention(self, q, k, v):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

+        def local_mask(b, h, q_idx, kv_idx):
+            # Left window check: allow tokens within left_window_size lookback
+            left_ok = q_idx - kv_idx <= self.window_size[0]
+            # Right window check: allow tokens within right_window_size lookahead
+            right_ok = kv_idx - q_idx <= self.window_size[1]
+            return left_ok & right_ok
+
         flex_attention = torch.compile(flex_attention, dynamic=False)

+        B, H, S, D = q.shape
+        _, _, S_KV, _ = k.shape
+
+        mask_mod = None
         if self.causal:
-            B, H, S, D = q.shape
+            mask_mod = causal_mask
+        elif self.local:
+            mask_mod = local_mask
+
+        if mask_mod:
             block_mask = create_block_mask(
-                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+                mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S_KV
             )
         else:
             block_mask = None
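
As a sanity check on the flex_attention mask_mod, a standalone sketch that evaluates the same local predicate over broadcast index tensors, which is logically what create_block_mask materializes (toy sizes and a hypothetical window; not part of the diff).

import torch

window_size = (2, 0)  # hypothetical (left, right) window for illustration

def local_mask(b, h, q_idx, kv_idx):
    left_ok = q_idx - kv_idx <= window_size[0]
    right_ok = kv_idx - q_idx <= window_size[1]
    return left_ok & right_ok

S, S_KV = 5, 5
q_idx = torch.arange(S).unsqueeze(1)
kv_idx = torch.arange(S_KV).unsqueeze(0)
print(local_mask(None, None, q_idx, kv_idx).int())
# Each query row attends to itself and up to two previous keys.
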
@@ -391,10 +460,24 @@ def flops(
         q, k, v = example_inputs
         BATCH, H, N_CTX, D_HEAD = q.shape
         _, _, N_CTX_KV, _ = k.shape
-        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX_KV * D_HEAD
-        flops = 2 * flops_per_matmul
-        if self.causal:
-            flops *= 0.5
+
+        if not self.local:
+            flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX_KV * D_HEAD
+            flops = 2 * flops_per_matmul
+            if self.causal:
+                flops *= 0.5
+        else:
+            row_idx = torch.arange(N_CTX, device="cuda")
+            col_left = torch.maximum(
+                row_idx + N_CTX_KV - N_CTX - self.window_size[0], torch.tensor(0)
+            )
+            col_right = torch.minimum(
+                row_idx + N_CTX_KV - N_CTX + self.window_size[1],
+                torch.tensor(N_CTX_KV - 1),
+            )
+            avg_seqlen = (col_right - col_left + 1).float().mean().item()
+            flops = 2 * 2.0 * BATCH * H * N_CTX * avg_seqlen * D_HEAD
+
         if self.mode == BenchmarkMode.BWD:
             flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         elif self.mode == BenchmarkMode.FWD_BWD:
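
A quick numeric check of the avg_seqlen estimate used for sliding-window FLOPS, run on CPU with toy sizes (same formula as above; not part of the diff). For N_CTX = N_CTX_KV = 8 and window_size = (2, 0), the rows see 1, 2, 3, 3, 3, 3, 3, 3 keys, so the average is 2.625.

import torch

N_CTX = N_CTX_KV = 8
left_window, right_window = 2, 0
row_idx = torch.arange(N_CTX)
col_left = torch.maximum(row_idx + N_CTX_KV - N_CTX - left_window, torch.tensor(0))
col_right = torch.minimum(
    row_idx + N_CTX_KV - N_CTX + right_window, torch.tensor(N_CTX_KV - 1)
)
avg_seqlen = (col_right - col_left + 1).float().mean().item()
print(avg_seqlen)  # 2.625
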
@@ -436,8 +519,15 @@ def get_input_iter(self) -> Generator:
             raise AssertionError(f"Unknown input type {self.input_types}")

     @register_x_val(label="(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)")
-    def get_x_val(self, example_inputs) -> float:
+    def get_x_val(self, example_inputs) -> str:
         q, k, v = example_inputs
         B, H, S, D = q.shape
         _, H_KV, S_KV, _ = k.shape
-        return (B, H, H_KV, S, S_KV, D)
+
+        # Add local mask info to the label if enabled
+        base_info = f"({B}, {H}, {H_KV}, {S}, {S_KV}, {D})"
+        if self.local:
+            base_info += f" Local {self.window_size[0]},{self.window_size[1]}"
+        if self.causal:
+            base_info += " Causal"
+        return base_info