from math import ceil
import torch
- from native_sparse_attention_pytorch.triton_native_sparse_attention import native_sparse_attend, round_up_multiple, pad_to_multiple
+
+ from native_sparse_attention_pytorch.native_sparse_attention import (
+     create_sliding_mask,
+     flex_attention
+ )
+
+ from native_sparse_attention_pytorch.triton_native_sparse_attention import (
+     native_sparse_attend,
+     round_up_multiple,
+     pad_to_multiple,
+ )

import einx
from einops import rearrange, einsum, repeat

def exists(v):
    return v is not None

+ def default(v, d):
+     return v if exists(v) else d
+
def abs_diff(x, y):
    return (x - y).abs().amax()

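The two names added to the first import block, create_sliding_mask and flex_attention, come from the package's flex-attention reference module. For orientation only, a causal sliding-window block mask along those lines might be built with torch's torch.nn.attention.flex_attention API as in the sketch below; the helper name and the inclusive window convention are assumptions, not the library's actual code.

# hedged sketch, not the library's implementation: a causal sliding-window
# block mask built with torch.nn.attention.flex_attention (PyTorch >= 2.5)

from torch.nn.attention.flex_attention import create_block_mask

def sliding_mask_sketch(seq_len, window_size, device = 'cuda'):
    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx                     # never attend to future positions
        in_window = (q_idx - kv_idx) <= window_size  # only the most recent tokens
        return causal & in_window

    # B = None / H = None broadcasts the mask over batch and heads
    return create_block_mask(mask_mod, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)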
@@ -21,12 +34,22 @@ def regular_attend(
    indices,
    mask,
    block_size,
+     sliding_window_size = None,
    sel_scale = None,
-     return_lse = False
+     return_lse = False,
+     return_sliding_window_out = False
):
    q_heads, seq_len, kv_heads, device = q.shape[1], q.shape[-2], k.shape[1], q.device
    assert divisible_by(q_heads, kv_heads)

+     if return_sliding_window_out:
+         kv_seq_len = k.shape[-2]
+         assert seq_len == kv_seq_len
+
+         sliding_window_size = default(sliding_window_size, block_size)
+         sliding_mask = create_sliding_mask(kv_seq_len, sliding_window_size)
+         sliding_out = flex_attention(q, k, v, block_mask = sliding_mask, enable_gqa = True)
+
    q, k, v = tuple(pad_to_multiple(t, block_size, dim = -2) for t in (q, k, v))

    if exists(sel_scale):
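The branch added above runs the unpadded q/k/v through flex_attention with the sliding mask before the selection path continues. As a sanity check (not part of the commit), that sliding output could be compared against plain scaled_dot_product_attention with an explicit band mask; the helper below and its inclusive window convention are assumptions.

# hedged cross-check sketch: sliding-window attention via an explicit boolean band mask
import torch.nn.functional as F

def sliding_window_reference(q, k, v, window_size):
    n, device = q.shape[-2], q.device

    # expand kv heads so grouped-query attention reduces to plain multi-head attention
    g = q.shape[1] // k.shape[1]
    k, v = tuple(repeat(t, 'b h n d -> b (h g) n d', g = g) for t in (k, v))

    diff = torch.arange(n, device = device)[:, None] - torch.arange(n, device = device)[None, :]
    band_mask = (diff >= 0) & (diff <= window_size)  # causal and within the window

    return F.scaled_dot_product_attention(q, k, v, attn_mask = band_mask)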
@@ -97,6 +120,9 @@ def regular_attend(

    out = out[..., :seq_len, :]

+     if return_sliding_window_out:
+         out = (out, sliding_out)
+
    if not return_lse:
        return out

@@ -114,6 +140,7 @@ def regular_attend(
kv_heads = 2
fine_block_size = 16
num_sel = 6
+ fused_sliding_window = False

q = torch.randn(batch, q_heads, seq_len, 64).cuda()
k = torch.randn(batch, kv_heads, seq_len, 64).cuda()
@@ -130,7 +157,11 @@ def regular_attend(

# regular forwards and backwards

- out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True)
+ out, rlse = regular_attend(rq, rk, rv, indices, mask, block_size = fine_block_size, sel_scale = rsel_scale, return_lse = True, return_sliding_window_out = fused_sliding_window)
+
+ if fused_sliding_window:
+     out = sum(out)
+
out.sum().backward()

# triton nsa forwards and backwards
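To recap the reference call exercised above before the Triton comparison: with the new flag enabled, regular_attend returns a (fine_out, sliding_out) pair alongside the LSE, and the test sums the pair so backward runs through both branches. A usage sketch reusing the test's tensors (illustrative only):

# usage sketch, mirroring the test with the fused sliding window turned on
fine_and_sliding, rlse = regular_attend(
    rq, rk, rv, indices, mask,
    block_size = fine_block_size,
    sel_scale = rsel_scale,
    return_lse = True,
    return_sliding_window_out = True
)

out = sum(fine_and_sliding)  # fine selection branch + sliding window branch
out.sum().backward()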