
Commit 30c6f6d

adding swa to tritonbench
Differential Revision: D82179875
Pull Request resolved: #415
1 parent: 06c0519

File tree

1 file changed: +102 −12 lines


tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 102 additions & 12 deletions
```diff
@@ -116,6 +116,12 @@ def parse_op_args(args: List[str]):
         action="store_true",
         help="enable causal",
     )
+    parser.add_argument(
+        "--window-size",
+        type=lambda x: tuple(map(int, x.split(","))),
+        default=(-1, -1),
+        help="sliding window size as (left_window, right_window). Use (-1, -1) to disable sliding window",
+    )
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
```
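The flag parses a comma-separated pair into an int tuple. A standalone sketch of the parsing behavior (not part of the diff; it mirrors the argument definition above):

```python
import argparse

# Minimal sketch of how the new --window-size flag parses.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--window-size",
    type=lambda x: tuple(map(int, x.split(","))),
    default=(-1, -1),
)

print(parser.parse_args(["--window-size", "128,0"]).window_size)  # (128, 0)
print(parser.parse_args([]).window_size)  # (-1, -1) -> sliding window disabled
```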
```diff
@@ -177,6 +183,13 @@ def __init__(
         self.H = args.n_heads
         self.D_HEAD = args.d_head
         self.causal = args.causal
+        self.window_size = args.window_size
+        self.local = self.window_size != (-1, -1)
+
+        # Prioritize sliding window over causal when both are specified
+        if self.causal and self.local:
+            self.causal = False
+
         self.native_sdpa = args.native_sdpa
         self.pt2_sdpa = args.pt2_sdpa
         self.input_types = args.input_types
```
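On the precedence rule: a window of (W, 0) permits no lookahead, so it is already causal and a separate causal mask would be redundant; a nonzero right window deliberately allows lookahead, so dropping --causal in that case is a real behavior change to keep in mind when comparing backends.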
```diff
@@ -191,10 +204,22 @@ def aten(
     ) -> Callable:
         def _inner():
             N_CTX = q.shape[2]
-            M = torch.tril(torch.ones((N_CTX, N_CTX), device=self.device))
+            N_CTX_KV = k.shape[2]
             p = torch.matmul(q, k.transpose(2, 3)) * self.sm_scale
+
             if self.causal:
+                M = torch.tril(torch.ones((N_CTX, N_CTX_KV), device=self.device))
                 p[:, :, M == 0] = float("-inf")
+            elif self.local:
+                # Create sliding window mask
+                i = torch.arange(N_CTX, device=self.device).unsqueeze(1)
+                j = torch.arange(N_CTX_KV, device=self.device).unsqueeze(0)
+                # Allow attention if within window (both left and right)
+                left_window, right_window = self.window_size
+                window_mask = ((i - j) <= left_window) & ((j - i) <= right_window)
+                # Note: causal is handled above and is forced off when a window is set
+                p[:, :, ~window_mask] = float("-inf")
+
             p = torch.softmax(p.float(), dim=-1).to(q.dtype)
             # p = torch.exp(p)
             ref_out = torch.matmul(p, v)
```
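Note the parenthesization on the window_mask line: in Python, & binds tighter than <=, so both comparisons must be wrapped or the expression evaluates left_window & (...) first and builds the wrong mask. A minimal CPU check of the mask, with hypothetical sizes:

```python
import torch

# Hypothetical sizes; the operator derives N_CTX/N_CTX_KV from the inputs.
N_CTX, N_CTX_KV = 6, 6
left_window, right_window = 2, 0  # look 2 tokens back, none ahead

i = torch.arange(N_CTX).unsqueeze(1)     # query positions, column vector
j = torch.arange(N_CTX_KV).unsqueeze(0)  # key positions, row vector
window_mask = ((i - j) <= left_window) & ((j - i) <= right_window)

print(window_mask.int())
# Row q may attend keys j with q-2 <= j <= q: a banded, causal pattern.
```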
```diff
@@ -209,6 +234,10 @@ def sdpa(
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
+        if self.local:
+            # sdpa with flash attention backend doesn't support non-null attn_mask
+            raise NotImplementedError("Skip")
+
         def sdpa_flash_attention(q, k, v):
             cxt = (
                 nullcontext()
```
```diff
@@ -249,7 +278,10 @@ def flash_v2(
     ) -> Callable:
         qkv = make_packed_qkv(q, k, v)
         fn = lambda: flash_attn_func(
-            qkv, softmax_scale=self.sm_scale, causal=self.causal
+            qkv,
+            softmax_scale=self.sm_scale,
+            causal=self.causal,
+            window_size=self.window_size,
         )
         return fn
```
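flash-attn has accepted window_size=(left, right) since v2.3, with (-1, -1) meaning unbounded, so forwarding the CLI default through unchanged is a no-op for non-local runs. A sketch of the convention with unpacked tensors (hypothetical shapes; the benchmark itself uses the packed-qkv entry point):

```python
import torch
from flash_attn import flash_attn_func  # requires flash-attn >= 2.3 and a CUDA GPU

# Hypothetical shapes: (batch, seqlen, nheads, headdim).
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(
    q, k, v,
    causal=False,          # redundant when the right window is 0
    window_size=(128, 0),  # 128 tokens of lookback, no lookahead
)                          # window_size=(-1, -1) disables windowing
```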

```diff
@@ -264,7 +296,17 @@ def xformers_preprocess(
         q_1 = q_1.contiguous()
         k_1 = k_1.contiguous()
         v_1 = v_1.contiguous()
-        attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
+
+        # Create attention bias based on settings
+        attn_bias = None
+        if self.causal:
+            attn_bias = xformers.ops.LowerTriangularMask()
+        elif self.local:
+            attn_bias = xformers.ops.fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=self.window_size[0],
+                window_right=self.window_size[1],
+            )
+
         fhma_input = xformers_fmha.Inputs(
             query=q_1, key=k_1, value=v_1, attn_bias=attn_bias, scale=self.sm_scale
         )
```
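LocalAttentionFromBottomRightMask anchors the window at the bottom-right corner of the attention matrix, i.e. at the end of the key sequence, which matches the flash-attn convention when seqlen_q != seqlen_kv. A sketch of the same bias used directly (hypothetical shapes):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import LocalAttentionFromBottomRightMask

# Hypothetical shapes: xformers expects (batch, seqlen, heads, headdim).
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Bottom-right alignment: the window is measured from the end of the keys.
bias = LocalAttentionFromBottomRightMask(window_left=128, window_right=0)
out = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
```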
```diff
@@ -291,6 +333,9 @@ def xformers_splitk(
         k: torch.Tensor,
         v: torch.Tensor,
     ):
+        if self.local or self.causal:
+            # SplitK doesn't support local or causal attention yet
+            raise NotImplementedError("Skip")
         need_gradient = not (self.mode == BenchmarkMode.FWD_NO_GRAD)
         fhma_input = self.xformers_preprocess(q, k, v)
         xformers_splitk_fhma = xformers_fmha.triton_splitk.FwOp
```
```diff
@@ -303,6 +348,10 @@ def xformers_splitk(
         label=f"cudnn-sdpa-{torch.backends.cudnn.version()}",
     )
     def cudnn_sdpa(self, q, k, v):
+        if self.local:
+            # Skip CUDNN SDPA for local attention for now
+            raise NotImplementedError("Skip")
+
         return lambda: _sdpa_cudnn_attention(
             q, k, v, is_causal=self.causal, scale=self.sm_scale
         )
```
```diff
@@ -318,7 +367,12 @@ def cutedsl_blackwell(
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
         return lambda: facute_flash_attn_func(
-            q, k, v, softmax_scale=self.sm_scale, causal=self.causal
+            q,
+            k,
+            v,
+            softmax_scale=self.sm_scale,
+            causal=self.causal,
+            window_size=self.window_size if self.local else (None, None),
         )

     @register_benchmark()
```
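Note the translation on the last argument: judging from this call, the CuteDSL entry point uses None rather than -1 to mean an unbounded side, so the CLI's (-1, -1) sentinel is mapped to (None, None) when no window is set.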
```diff
@@ -328,12 +382,27 @@ def flex_attention(self, q, k, v):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

+        def local_mask(b, h, q_idx, kv_idx):
+            # Left window check: allow tokens within left_window_size lookback
+            left_ok = q_idx - kv_idx <= self.window_size[0]
+            # Right window check: allow tokens within right_window_size lookahead
+            right_ok = kv_idx - q_idx <= self.window_size[1]
+            return left_ok & right_ok
+
         flex_attention = torch.compile(flex_attention, dynamic=False)

+        B, H, S, D = q.shape
+        _, _, S_KV, _ = k.shape
+
+        mask_mod = None
         if self.causal:
-            B, H, S, D = q.shape
+            mask_mod = causal_mask
+        elif self.local:
+            mask_mod = local_mask
+
+        if mask_mod:
             block_mask = create_block_mask(
-                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+                mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S_KV
             )
         else:
             block_mask = None
```
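The same banded predicate can be exercised standalone against FlexAttention (PyTorch 2.5+); a minimal sketch with a fixed, hypothetical window and sizes:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

WINDOW = (128, 0)  # (left, right), hypothetical

def local_mask(b, h, q_idx, kv_idx):
    # Same banded predicate the benchmark registers for the local case.
    return (q_idx - kv_idx <= WINDOW[0]) & (kv_idx - q_idx <= WINDOW[1])

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

block_mask = create_block_mask(local_mask, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
out = flex_attention(q, k, v, block_mask=block_mask)
```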
```diff
@@ -391,10 +460,24 @@ def flops(
         q, k, v = example_inputs
         BATCH, H, N_CTX, D_HEAD = q.shape
         _, _, N_CTX_KV, _ = k.shape
-        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX_KV * D_HEAD
-        flops = 2 * flops_per_matmul
-        if self.causal:
-            flops *= 0.5
+
+        if not self.local:
+            flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX_KV * D_HEAD
+            flops = 2 * flops_per_matmul
+            if self.causal:
+                flops *= 0.5
+        else:
+            row_idx = torch.arange(N_CTX, device="cuda")
+            col_left = torch.maximum(
+                row_idx + N_CTX_KV - N_CTX - self.window_size[0], torch.tensor(0)
+            )
+            col_right = torch.minimum(
+                row_idx + N_CTX_KV - N_CTX + self.window_size[1],
+                torch.tensor(N_CTX_KV - 1),
+            )
+            avg_seqlen = (col_right - col_left + 1).float().mean().item()
+            flops = 2 * 2.0 * BATCH * H * N_CTX * avg_seqlen * D_HEAD
+
         if self.mode == BenchmarkMode.BWD:
             flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         elif self.mode == BenchmarkMode.FWD_BWD:
```
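The local branch replaces the flat 0.5 causal discount with the exact average number of valid keys per query row, aligned bottom-right so it stays correct when N_CTX != N_CTX_KV. A worked CPU check with hypothetical sizes (the benchmark hardcodes device="cuda"):

```python
import torch

# Hypothetical sizes: square attention, window of 2 back / 0 ahead.
N_CTX = N_CTX_KV = 8
left_window, right_window = 2, 0

row_idx = torch.arange(N_CTX)
col_left = torch.maximum(row_idx + N_CTX_KV - N_CTX - left_window, torch.tensor(0))
col_right = torch.minimum(
    row_idx + N_CTX_KV - N_CTX + right_window, torch.tensor(N_CTX_KV - 1)
)
avg_seqlen = (col_right - col_left + 1).float().mean().item()

print(avg_seqlen)  # 2.625: early rows see fewer than left_window + 1 keys
# Full (non-local) attention would use all N_CTX_KV = 8 keys per row instead.
```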
```diff
@@ -436,8 +519,15 @@ def get_input_iter(self) -> Generator:
             raise AssertionError(f"Unknown input type {self.input_types}")

     @register_x_val(label="(Batch, Heads, Heads_KV, SeqLen, SeqLen_KV, Dhead)")
-    def get_x_val(self, example_inputs) -> float:
+    def get_x_val(self, example_inputs) -> str:
         q, k, v = example_inputs
         B, H, S, D = q.shape
         _, H_KV, S_KV, _ = k.shape
-        return (B, H, H_KV, S, S_KV, D)
+
+        # Add local mask info to the label if enabled
+        base_info = f"({B}, {H}, {H_KV}, {S}, {S_KV}, {D})"
+        if self.local:
+            base_info += f" Local {self.window_size[0]},{self.window_size[1]}"
+        if self.causal:
+            base_info += " Causal"
+        return base_info
```
