Commit f0ee01e

harsh-nod authored and nithinsubbiah committed

[Wave] Add support for sliding window attention (#626)

Signed-off-by: Harsh Menon <[email protected]>
Signed-off-by: nithinsubbiah <[email protected]>
1 parent e25b3ca commit f0ee01e

File tree

4 files changed: +82 -5 lines changed


iree/turbine/kernel/wave/templates/vanilla_attention.py

Lines changed: 10 additions & 0 deletions
@@ -19,7 +19,14 @@ def get_vanilla_attention_kernel(
     dynamic_dims: bool,
     is_causal: bool = False,
     is_v_transposed: bool = False,
+    sliding_window_size: int = -1,
 ):
+
+    if sliding_window_size > 0 and not is_causal:
+        raise NotImplementedError(
+            "Sliding window is only supported for causal attention."
+        )
+
     # Input sizes
     B = tkl.sym.B
     M = tkl.sym.M
@@ -78,6 +85,7 @@ def base_attention_core(q, k, v, c):
         c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
         init_sum = tkl.Register[B, M, tkl.f32](0.0)
         init_max = tkl.Register[B, M, tkl.f32](-1e6)
+        sliding_window = tkl.Register[M, K2, tkl.i64](sliding_window_size)
         ZEROF = tkl.Register[M, K2, tkl.f32](0.0)
         MIN_INF = tkl.Register[M, K2, tkl.f32](-1e6)

@@ -106,6 +114,8 @@ def repeat(
             m_index = tkw.self_index(M, tkl.i64)
             m_index = tkw.broadcast(m_index, target_shape=[M, K2])
             mask = (m_index >= k2_index) & mask
+            if sliding_window_size > 0:
+                mask = (m_index - k2_index <= sliding_window) & mask
             mask = tkw.cast(mask, tkw.i1)
             bias = tkw.select(mask, ZEROF, MIN_INF)
             x_j = x_j + bias
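
For intuition: the combined predicate above keeps key position k2 for query position m only when k2 <= m (causal) and m - k2 <= sliding_window_size. A minimal PyTorch sketch of the same dense mask, for illustration only (the function name and arguments are not part of the kernel):

import torch

def dense_sliding_window_mask(q_len: int, kv_len: int, window: int) -> torch.Tensor:
    # m_index / k2_index mirror the query and key position indices used in the kernel.
    m_index = torch.arange(q_len).unsqueeze(1)    # shape [q_len, 1]
    k2_index = torch.arange(kv_len).unsqueeze(0)  # shape [1, kv_len]
    causal = m_index >= k2_index                  # mask = (m_index >= k2_index) & mask
    in_window = (m_index - k2_index) <= window    # mask = (m_index - k2_index <= sliding_window) & mask
    return causal & in_window

Positions where this mask is False receive the MIN_INF bias before the softmax, so they contribute effectively zero attention weight.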

iree/turbine/kernel/wave/utils/torch_utils.py

Lines changed: 4 additions & 0 deletions
@@ -49,3 +49,7 @@ def device_randperm(*args, **kwargs):

 def device_zeros(*args, **kwargs):
     return to_default_device(torch.zeros(*args, **kwargs))
+
+
+def device_ones(*args, **kwargs):
+    return to_default_device(torch.ones(*args, **kwargs))
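
A small usage sketch of the new helper, assuming the torch_utils module above; the shape and dtype are illustrative:

import torch
from iree.turbine.kernel.wave.utils.torch_utils import device_ones

# All-True boolean tensor placed on the default device, analogous to device_zeros.
mask = device_ones((128, 256), dtype=torch.bool)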

lit_tests/kernel/wave/attention/attention.py

Lines changed: 46 additions & 1 deletion
@@ -20,7 +20,6 @@
 )
 from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
 from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile
-import torch

 # Input sizes
 B = tkl.sym.B
@@ -429,3 +428,49 @@ def test_attention_bshd():
     # CHECK-COUNT-8: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4xf32>
     # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}}
     # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
+
+
+@run_test
+def test_attention_sliding_window():
+    shape = AttentionShape(
+        num_query_heads=8,
+        num_kv_heads=8,
+        query_seq_len=128,
+        head_size_kv=128,
+        head_size=64,
+        kv_seq_len=256,
+    )
+    mfma_variant = (tkw.MMAType.F32_16x16x16_F16,) * 2
+    base_attention, hyperparams, _, _ = get_vanilla_attention_kernel(
+        shape, mfma_variant, False, is_causal=True, sliding_window_size=1024
+    )
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        compile_to_mlir=True,
+    )
+    base_attention = wave_compile(options, base_attention)
+    print(base_attention.asm)
+
+    # CHECK-LABEL: func.func @base_attention
+    # CHECK: %[[NEG_INF:.+]] = arith.constant dense<-1.000000e+06> : vector<4xf32>
+    # CHECK: %[[WINDOW_SIZE:.+]] = arith.constant dense<1024> : vector<4xi64>
+    # CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    # CHECK: {{.*}} = scf.for
+    # CHECK-COUNT-32: {{.*}} = amdgpu.mfma
+    # CHECK-COUNT-4: {{.*}} = arith.cmpi slt, {{.*}} : vector<4xindex>
+    # CHECK-COUNT-8: {{.*}} = arith.cmpi sge, {{.*}} : vector<4xi64>
+    # CHECK-COUNT-8: {{.*}} = arith.andi {{.*}} : vector<4xi1>
+    # This is computing the index difference: m_index - k2_index
+    # CHECK-COUNT-8: {{.*}} = arith.subi {{.*}} : vector<4xi64>
+    # And then comparing to the window size: m_index - k2_index <= window_size
+    # CHECK-COUNT-8: {{.*}} = arith.cmpi sle, {{.*}}, %[[WINDOW_SIZE]] : vector<4xi64>
+    # CHECK-COUNT-8: {{.*}} = arith.andi {{.*}} : vector<4xi1>
+    # CHECK-COUNT-8: {{.*}} = arith.select %{{.*}}, %[[ZERO]], %[[NEG_INF]] : vector<4xi1>, vector<4xf32>
+    # CHECK-COUNT-8: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<4xf32>
+    # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}}
+    # CHECK-COUNT-32: {{.*}} = amdgpu.mfma
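
The arith.select / arith.addf patterns checked above correspond to the additive-bias masking trick: masked positions receive a large negative bias before the softmax so their weights vanish. A tiny standalone PyTorch illustration (the values are made up):

import torch

scores = torch.tensor([1.0, 2.0, 3.0])
keep = torch.tensor([True, True, False])          # result of the causal/sliding-window comparisons
bias = torch.where(keep, torch.tensor(0.0), torch.tensor(-1e6))  # select(mask, ZEROF, MIN_INF)
weights = torch.softmax(scores + bias, dim=0)     # x_j = x_j + bias, then softmax
# weights[2] is ~0, so the masked key contributes essentially nothing.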

tests/kernel/wave/attention/vanilla_attention_test.py

Lines changed: 22 additions & 4 deletions
@@ -21,6 +21,7 @@
 from iree.turbine.kernel.wave.utils.torch_utils import (
     device_randn,
     device_zeros,
+    device_ones,
 )
 from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile
 from iree.turbine.kernel.wave.constraints import MMAType
@@ -217,6 +218,7 @@ def testAttentionPure(
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("all_attention"))
 @pytest.mark.parametrize("enable_scheduling", [SchedulingType.NONE])
+@pytest.mark.parametrize("sliding_window", ([-1, 1024]))
 @param_bool("dynamic_dims", "dyn", [False])
 @pytest.mark.parametrize(
     "mfma_variant",
@@ -228,6 +230,7 @@ def testAttentionPure(
 def testAttentionCausal(
     shape: tuple[int],
     enable_scheduling: SchedulingType,
+    sliding_window: int,
     dynamic_dims: bool,
     mfma_variant: tuple[MMAType],
     request,
@@ -248,7 +251,12 @@ def testAttentionCausal(
         dynamic_symbols,
         dynamic_symbols_map,
     ) = get_vanilla_attention_kernel(
-        shape, mfma_variant, dynamic_dims, is_causal=True, is_v_transposed=True
+        shape,
+        mfma_variant,
+        dynamic_dims,
+        is_causal=True,
+        is_v_transposed=True,
+        sliding_window_size=sliding_window,
     )
     q_shape = (shape.num_query_heads, shape.query_seq_len, shape.head_size)
     k_shape = (shape.num_kv_heads, shape.kv_seq_len, shape.head_size)
@@ -284,9 +292,19 @@ def testAttentionCausal(
     dk_sqrt = math.sqrt(1.0 / shape.head_size)
     # TODO: Add scaling of QK as part of kernel.
     asm = base_attention(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output)
-    torch_ref = torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, is_causal=True
-    )
+    if sliding_window >= 0:
+
+        def sliding_window_mask(q_seq_length, kv_seq_length, window_size):
+            mask = device_ones((q_seq_length, kv_seq_length), dtype=torch.bool)
+            mask = mask.tril().triu(-sliding_window)
+            return mask.to(dtype=torch.bool)
+
+        mask = sliding_window_mask(
+            shape.query_seq_len, shape.kv_seq_len, sliding_window
+        )
+        torch_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+    else:
+        torch_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

     if dump_generated_mlir:
         filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
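
The tril().triu(-window) chain in sliding_window_mask builds the reference mask in two steps: tril() keeps keys at or before the query position (causal), and triu(-window) drops keys more than window positions behind it. A small worked example, with shape and window size chosen only for illustration:

import torch

window = 2
mask = torch.ones((5, 5), dtype=torch.bool).tril().triu(-window)
# Row i allows key columns max(0, i - window) .. i; e.g. row 4 -> columns 2, 3, 4.
print(mask.int())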
