diff --git a/native_sparse_attention/ops/triton/topk_sparse_attention.py b/native_sparse_attention/ops/triton/topk_sparse_attention.py
index af5bc64..1193db3 100644
--- a/native_sparse_attention/ops/triton/topk_sparse_attention.py
+++ b/native_sparse_attention/ops/triton/topk_sparse_attention.py
@@ -548,8 +548,19 @@ def backward_dkdv(
     k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
     if BLOCK_SIZE_K * pid_k >= k_len:
         return
-    # get topk_q_idx
-    b_start = tl.load(cu_seqblocks + pid_b)  # how many blocks before current sequence
+
+    # Check if we're accessing valid blocks for this batch
+    b_start = tl.load(cu_seqblocks + pid_b)
+    b_end = tl.load(cu_seqblocks + pid_b + 1)
+    total_blocks_in_batch = b_end - b_start
+
+    if pid_k >= total_blocks_in_batch:
+        return
+
+    # Ensure we don't access beyond cu_topk_q_count bounds
+    if (b_start + pid_k + 1) >= (b_end):
+        return
+
     act_q_start = tl.load(
         cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn
     )
@@ -557,6 +568,11 @@ def backward_dkdv(
         cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn
     )
     act_q_len = act_q_end - act_q_start
+
+    # bounds check for act_q_len
+    if act_q_len <= 0:
+        return
+
     tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
     # init pointers
     k_ptrs = tl.make_block_ptr(
@@ -612,49 +628,57 @@ def backward_dkdv(
     lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
     # loop for q blocks
     for i in range(0, act_q_len, BLOCK_SIZE_Q):
-        # load
-        idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(
-            tl.int32
-        )
+        # load with enhanced bounds checking
+        load_mask = off_q < act_q_len - i
+        idx_q = tl.load(tq_ptr + i + off_q, mask=load_mask, other=-1).to(tl.int32)
+
+        # Sanitize idx_q to prevent invalid memory access
+        idx_q = tl.where(load_mask, idx_q, 0)  # Replace invalid indices with 0
+        idx_q = tl.maximum(idx_q, 0)  # Ensure non-negative
+        idx_q = tl.minimum(idx_q, q_len - 1)  # Ensure within bounds
+
+        # Create enhanced masks for memory access
+        valid_q_mask = load_mask & (idx_q >= 0) & (idx_q < q_len)
+
         q = tl.load(
             q_ptrs + idx_q[:, None] * stride_qn,
-            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
+            mask=valid_q_mask[:, None] & (off_d < HEAD_DIM)[None, :],
             other=0,
-        )
+        )
         do = tl.load(
             do_ptrs + idx_q[:, None] * stride_don,
-            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
+            mask=valid_q_mask[:, None] & (off_d < HEAD_DIM)[None, :],
             other=0,
-        )
+        )
         lse = tl.load(
-            lse_ptrs + idx_q[:, None] * stride_ln,
-            mask=(off_q < act_q_len - i)[:, None],
+            lse_ptrs + idx_q * stride_ln,
+            mask=valid_q_mask,
             other=0,
-        )
+        )
         d = tl.load(
-            d_ptrs + idx_q[:, None] * stride_dn,
-            mask=(off_q < act_q_len - i)[:, None],
+            d_ptrs + idx_q * stride_dn,
+            mask=valid_q_mask,
             other=0,
-        )
+        )
         # compute qk
         qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
         qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
         qk += tl.dot(q, k.T) * qk_scale
         # compute p, ds
-        p = tl.exp2(qk - lse)
+        p = tl.exp2(qk - lse[:, None])
         dp = tl.dot(do, v.T)
-        ds = sm_scale * p * (dp - d)
+        ds = sm_scale * p * (dp - d[:, None])
         # cast dtype
         p = p.to(do.dtype)
         ds = ds.to(q.dtype)
-        # update dk and dv
-        dk += tl.dot(ds.T, q)
-        dv += tl.dot(p.T, do)
+        # update dk and dv (only for valid queries)
+        valid_mask_2d = valid_q_mask[:, None]
+        dk += tl.dot(tl.where(valid_mask_2d, ds, 0.0).T, q)
+        dv += tl.dot(tl.where(valid_mask_2d, p, 0.0).T, do)
     # save dk dv
     tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
     tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
-
 
 @triton.jit
 def backward_dq(
     q_ptr,  # Q: n x qh x d
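
For reference, a minimal NumPy sketch (not part of the patch) of the masking idea used in the dk/dv update above: lanes past act_q_len get their gather index clamped to a valid row, and their rows are zeroed before the dot products, so padded lanes cannot contribute to the accumulated gradients. All names, shapes, and values here are illustrative only.

# Minimal NumPy sketch of the masked-accumulation pattern in backward_dkdv.
# Hypothetical shapes; not the kernel itself.
import numpy as np

rng = np.random.default_rng(0)

BLOCK_SIZE_Q, BLOCK_SIZE_K, HEAD_DIM = 8, 4, 16
act_q_len = 5   # only the first 5 lanes of the block hold real query indices
q_len = 32

topk_q_idx = rng.integers(0, q_len, size=act_q_len)
q_all = rng.standard_normal((q_len, HEAD_DIM))
ds = rng.standard_normal((BLOCK_SIZE_Q, BLOCK_SIZE_K))

# Emulate the kernel: a full block of lanes, some of them past act_q_len.
off_q = np.arange(BLOCK_SIZE_Q)
load_mask = off_q < act_q_len
idx_q = np.where(load_mask, np.pad(topk_q_idx, (0, BLOCK_SIZE_Q - act_q_len)), 0)
idx_q = np.clip(idx_q, 0, q_len - 1)          # sanitize, as in the patch

q = q_all[idx_q]                               # gather with clamped indices
dk_masked = np.where(load_mask[:, None], ds, 0.0).T @ q   # zero invalid rows first

# Reference: accumulate only over the genuinely valid lanes.
dk_ref = ds[:act_q_len].T @ q_all[topk_q_idx]

assert np.allclose(dk_masked, dk_ref)
print("masked accumulation matches the valid-lane reference")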