Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions native_sparse_attention/ops/triton/topk_sparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,15 +548,31 @@ def backward_dkdv(
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
if BLOCK_SIZE_K * pid_k >= k_len:
return
# get topk_q_idx
b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence

# Check if we're accessing valid blocks for this batch
b_start = tl.load(cu_seqblocks + pid_b)
b_end = tl.load(cu_seqblocks + pid_b + 1)
total_blocks_in_batch = b_end - b_start

if pid_k >= total_blocks_in_batch:
return

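# Ensure the act_q_end read at index (b_start + pid_k + 1) stays within cu_topk_q_count bounds.
# NOTE(review): with `>=` this guard also fires when pid_k is the last block of the
# sequence (b_start + pid_k + 1 == b_end), so the kernel returns before the dk/dv
# stores for that block. If cu_topk_q_count has a valid entry at index b_end
# (TODO confirm against the caller), the comparison should be `>` instead.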
if (b_start + pid_k + 1) >= (b_end):
return

act_q_start = tl.load(
cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn
)
act_q_end = tl.load(
cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn
)
act_q_len = act_q_end - act_q_start

# bounds check for act_q_len
if act_q_len <= 0:
return

tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
# init pointers
k_ptrs = tl.make_block_ptr(
Expand Down Expand Up @@ -612,49 +628,57 @@ def backward_dkdv(
lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
# loop for q blocks
for i in range(0, act_q_len, BLOCK_SIZE_Q):
# load
idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(
tl.int32
)
# load with enhanced bounds checking
load_mask = off_q < act_q_len - i
idx_q = tl.load(tq_ptr + i + off_q, mask=load_mask, other=-1).to(tl.int32)

# Sanitize idx_q to prevent invalid memory access
idx_q = tl.where(load_mask, idx_q, 0) # Replace invalid indices with 0
idx_q = tl.maximum(idx_q, 0) # Ensure non-negative
idx_q = tl.minimum(idx_q, q_len - 1) # Ensure within bounds

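# Build the row mask used for the q / do / lse / d loads below.
# NOTE(review): idx_q has already been clamped to [0, q_len - 1] above, so whenever
# q_len >= 1 the two range terms are always true and this mask equals load_mask;
# out-of-range indices are clamped to a valid row rather than masked out — confirm
# that this is the intended behavior.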
valid_q_mask = load_mask & (idx_q >= 0) & (idx_q < q_len)

q = tl.load(
q_ptrs + idx_q[:, None] * stride_qn,
mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
mask=valid_q_mask[:, None] & (off_d < HEAD_DIM)[None, :],
other=0,
)
)
do = tl.load(
do_ptrs + idx_q[:, None] * stride_don,
mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
mask=valid_q_mask[:, None] & (off_d < HEAD_DIM)[None, :],
other=0,
)
)
lse = tl.load(
lse_ptrs + idx_q[:, None] * stride_ln,
mask=(off_q < act_q_len - i)[:, None],
lse_ptrs + idx_q * stride_ln,
mask=valid_q_mask,
other=0,
)
)
d = tl.load(
d_ptrs + idx_q[:, None] * stride_dn,
mask=(off_q < act_q_len - i)[:, None],
d_ptrs + idx_q * stride_dn,
mask=valid_q_mask,
other=0,
)
)
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
qk += tl.dot(q, k.T) * qk_scale
# compute p, ds
p = tl.exp2(qk - lse)
p = tl.exp2(qk - lse[:, None])
dp = tl.dot(do, v.T)
ds = sm_scale * p * (dp - d)
ds = sm_scale * p * (dp - d[:, None])
# cast dtype
p = p.to(do.dtype)
ds = ds.to(q.dtype)
# update dk and dv
dk += tl.dot(ds.T, q)
dv += tl.dot(p.T, do)
# update dk and dv (only for valid queries)
valid_mask_2d = valid_q_mask[:, None]
dk += tl.dot(tl.where(valid_mask_2d, ds, 0.0).T, q)
dv += tl.dot(tl.where(valid_mask_2d, p, 0.0).T, do)
# save dk dv
tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))


@triton.jit
def backward_dq(
q_ptr, # Q: n x qh x d
Expand Down