
Commit 1d3712c

sliding windows logic for backwards

1 parent 8628b12

1 file changed


native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 104 additions & 85 deletions
@@ -180,112 +180,121 @@ def forward_kernel_causal_and_sparse(
                 other = 0.0
             )
 
+    q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+
     if INCLUDE_BLOCK_CAUSAL:
 
-        offs_n = start_m * BLOCK + tl.arange(0, BLOCK)
+        if SLIDING:
+            num_kv_blocks = 2
+            offset = -BLOCK
+        else:
+            num_kv_blocks = 1
+            offset = 0
 
-        k_ptrs = (
-            K +
-            off_b * stride_kb +
-            off_h * stride_kh +
-            offs_n[:, None] * stride_kn +
-            offs_d[None, :]
-        )
+        offs_n = start_m * BLOCK + tl.arange(0, BLOCK) + offset
 
-        v_ptrs = (
-            V +
-            off_b * stride_vb +
-            off_h * stride_vh +
-            offs_n[:, None] * stride_vn +
-            offs_d[None, :]
-        )
+        for _ in range(num_kv_blocks):
 
-        if EVEN_N & EVEN_M:
-            if EVEN_HEADDIM:
-                k = tl.load(k_ptrs)
-            else:
-                k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-        else:
-            if EVEN_HEADDIM:
-                k = tl.load(
-                    k_ptrs,
-                    mask = offs_n[:, None] < seqlen_k,
-                    other = 0.0,
-                )
-            else:
-                k = tl.load(
-                    k_ptrs,
-                    mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                    other = 0.0,
-                )
+            k_ptrs = (
+                K +
+                off_b * stride_kb +
+                off_h * stride_kh +
+                offs_n[:, None] * stride_kn +
+                offs_d[None, :]
+            )
 
-        qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype=tl.float32)
+            v_ptrs = (
+                V +
+                off_b * stride_vb +
+                off_h * stride_vh +
+                offs_n[:, None] * stride_vn +
+                offs_d[None, :]
+            )
 
-        q = q.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+            if EVEN_N & EVEN_M:
+                if EVEN_HEADDIM:
+                    k = tl.load(k_ptrs)
+                else:
+                    k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            else:
+                if EVEN_HEADDIM:
+                    k = tl.load(
+                        k_ptrs,
+                        mask = offs_n[:, None] < seqlen_k,
+                        other = 0.0,
+                    )
+                else:
+                    k = tl.load(
+                        k_ptrs,
+                        mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                        other = 0.0,
+                    )
 
-        qk += tl.dot(q, tl.trans(k))
+            qk = tl.zeros([BLOCK * QUERY_HEAD_GROUPS, BLOCK], dtype=tl.float32)
 
-        qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+            qk += tl.dot(q, tl.trans(k))
 
-        if not EVEN_N:
-            within_range_mask = offs_n[None, :] < seqlen_k
+            qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
 
-            if SLIDING:
-                within_range_mask &= offs_n[None, :] >= 0.
+            if not EVEN_N:
+                within_range_mask = offs_n[None, :] < seqlen_k
 
-            qk += tl.where(within_range_mask, 0, float("-inf"))
+                if SLIDING:
+                    within_range_mask &= offs_n[None, :] >= 0.
 
-        qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
+                qk += tl.where(within_range_mask, 0, float("-inf"))
 
-        causal_mask = offs_m[:, None, None] >= offs_n[None, None, :]
+            qk = qk.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK)
 
-        if SLIDING:
-            causal_mask &= (offs_n[None, None, :] - offs_m[:, None, None]) <= BLOCK
+            causal_mask = offs_m[:, None, None] >= offs_n[None, None, :]
 
-        qk += tl.where(causal_mask, 0, float("-inf"))
+            if SLIDING:
+                causal_mask &= (offs_n[None, None, :] - offs_m[:, None, None]) <= BLOCK
 
-        m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
-        p = tl.exp(qk * softmax_scale - m_ij[:, :, None])
+            qk += tl.where(causal_mask, 0, float("-inf"))
 
-        l_ij = tl.sum(p, 2)
+            m_ij = tl.maximum(tl.max(qk, 2) * softmax_scale, lse_i)
+            p = tl.exp(qk * softmax_scale - m_ij[:, :, None])
 
-        acc_o_scale = tl.exp(m_i - m_ij)
-        acc_o *= acc_o_scale[:, :, None]
+            l_ij = tl.sum(p, 2)
 
-        if EVEN_N & EVEN_M:
-            if EVEN_HEADDIM:
-                v = tl.load(v_ptrs)
-            else:
-                v = tl.load(
-                    v_ptrs,
-                    mask = offs_d[None, :] < headdim,
-                    other = 0.0
-                )
-        else:
-            if EVEN_HEADDIM:
-                v = tl.load(
-                    v_ptrs,
-                    mask = offs_n[:, None] < seqlen_k,
-                    other = 0.0,
-                )
+            acc_o_scale = tl.exp(m_i - m_ij)
+            acc_o *= acc_o_scale[:, :, None]
+
+            if EVEN_N & EVEN_M:
+                if EVEN_HEADDIM:
+                    v = tl.load(v_ptrs)
+                else:
+                    v = tl.load(
+                        v_ptrs,
+                        mask = offs_d[None, :] < headdim,
+                        other = 0.0
+                    )
             else:
-                v = tl.load(
-                    v_ptrs,
-                    mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                    other = 0.0,
-                )
+                if EVEN_HEADDIM:
+                    v = tl.load(
+                        v_ptrs,
+                        mask = offs_n[:, None] < seqlen_k,
+                        other = 0.0,
+                    )
+                else:
+                    v = tl.load(
+                        v_ptrs,
+                        mask = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                        other = 0.0,
+                    )
 
-        p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
+            p = p.reshape(BLOCK * QUERY_HEAD_GROUPS, BLOCK).to(v.dtype)
 
-        causal_o = tl.dot(p, v)
+            causal_o = tl.dot(p, v)
 
-        acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
+            acc_o += causal_o.reshape(BLOCK, QUERY_HEAD_GROUPS, BLOCK_HEADDIM)
 
-        # -- update statistics
+            # -- update statistics
 
-        m_i = m_ij
-        l_i_new = tl.exp(lse_i - m_ij) + l_ij
-        lse_i = m_ij + tl.log(l_i_new)
+            m_i = m_ij
+            l_i_new = tl.exp(lse_i - m_ij) + l_ij
+            lse_i = m_ij + tl.log(l_i_new)
 
     # # take care of the selected kv blocks
 
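For orientation: with SLIDING enabled, the branch above loops over two key/value blocks per query block rather than one, offsetting the key range back by one block, so each query covers its own block plus the block before it under the causal mask. A dense PyTorch sketch of that intended pattern, as I read this hunk (an illustration only, not code from the repository):

import torch

def sliding_block_causal_mask(seq_len, block):
    # Dense reference of the intended pattern: query position m may attend
    # key position n when n is causal (n <= m) and n lies in the query's
    # own block or the block immediately before it.
    pos = torch.arange(seq_len)
    m, n = pos[:, None], pos[None, :]
    causal = n <= m
    window = (m // block) - (n // block) <= 1
    return causal & window

The running softmax statistics (m_i, l_i_new, lse_i) are updated inside the loop, so the normalization stays exact across both visited key/value blocks.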
@@ -1029,6 +1038,7 @@ def backward_kernel_one_col_block_causal(
     BLOCK: tl.constexpr,
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
+    SLIDING: tl.constexpr
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
 
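SLIDING is passed into the backward kernel as a tl.constexpr, so Triton treats it as a compile-time constant and specializes the kernel per value; the if SLIDING: branches are resolved during compilation rather than evaluated per element. A small, self-contained toy kernel showing the same constexpr pattern (a generic illustration, not code from this repository):

import triton
import triton.language as tl

@triton.jit
def toy_kernel(x_ptr, out_ptr, n_elem, SLIDING: tl.constexpr, BLOCK: tl.constexpr):
    # each program instance handles BLOCK contiguous elements
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask = offs < n_elem, other = 0.)
    if SLIDING:  # constexpr branch: eliminated at compile time when False
        x = x * 2.
    tl.store(out_ptr + offs, x, mask = offs < n_elem)

# launch sketch: toy_kernel[(triton.cdiv(n, 128),)](x, out, n, SLIDING = True, BLOCK = 128)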
@@ -1143,11 +1153,16 @@ def backward_kernel_one_col_block_causal(
 
         qk = qk.reshape(QUERY_HEAD_GROUPS, BLOCK, BLOCK)
 
+        mask = offs_m[:, None] >= offs_n[None, :]
+
         # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N: # Need to mask out otherwise the softmax is wrong
-            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
+            mask &= offs_n[None, :] < seqlen_k
+
+        if SLIDING:
+            mask &= (offs_n[None, :] - offs_m[:, None]) < BLOCK
 
-        qk = tl.where(offs_m[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+        qk = tl.where(mask, qk, float("-inf"))
 
         qk = qk.reshape(QUERY_HEAD_GROUPS * BLOCK, BLOCK)
 
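The rewrite above folds the causal bound, the seqlen_k bound (only needed when the key length is not an even multiple of the block, hence the EVEN_N guard), and the new sliding-window bound into one mask that is applied to qk a single time. The same combined mask written as a dense PyTorch helper, for readability only (not code from this repository):

import torch

def combined_backward_mask(offs_m, offs_n, seqlen_k, block, sliding):
    # offs_m, offs_n: 1-D tensors of absolute query / key positions
    mask = offs_m[:, None] >= offs_n[None, :]              # causal
    mask = mask & (offs_n[None, :] < seqlen_k)             # in-range keys
    if sliding:
        # sliding-window bound, mirroring the hunk above
        mask = mask & ((offs_n[None, :] - offs_m[:, None]) < block)
    return mask

# qk = torch.where(combined_backward_mask(...), qk, float("-inf"))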
@@ -1315,7 +1330,8 @@ def backward_kernel(
     QUERY_HEAD_GROUPS: tl.constexpr,
     QUERY_EXPAND_DIM: tl.constexpr,
     RETURN_SEL_GRADS: tl.constexpr,
-    INCLUDE_BLOCK_CAUSAL: tl.constexpr
+    INCLUDE_BLOCK_CAUSAL: tl.constexpr,
+    SLIDING: tl.constexpr,
 ):
     off_hb = tl.program_id(1)
     off_b = off_hb // kv_heads
@@ -1393,6 +1409,7 @@ def backward_kernel(
             BLOCK = BLOCK,
             QUERY_HEAD_GROUPS = QUERY_HEAD_GROUPS,
             QUERY_EXPAND_DIM = QUERY_EXPAND_DIM,
+            SLIDING = SLIDING
         )
     else:
         for start_n in range(0, num_block_n):
@@ -1448,7 +1465,8 @@ def native_sparse_attn_backward(
     dq, dk, dv,
     block_size = 128,
     include_block_causal = True,
-    return_sel_grads = False
+    return_sel_grads = False,
+    sliding = False
 ):
     device = do.device
 
@@ -1563,6 +1581,7 @@ def native_sparse_attn_backward(
         EVEN_HEADDIM = BLOCK_HEADDIM == dim,
         RETURN_SEL_GRADS = return_sel_grads,
         INCLUDE_BLOCK_CAUSAL = include_block_causal,
+        SLIDING = sliding
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
         # num_stages=1,
@@ -1600,7 +1619,7 @@ def forward(
         selected_block_indices,
         fmask,
         block_size = block_size,
-        include_block_causal = include_block_causal
+        include_block_causal = include_block_causal,
     )
 
     ctx.save_for_backward(fq, fk, fv, selected_block_indices, fmask, out, lse)
