Skip to content

Commit 76d5507

Browse files
committed
more guards
1 parent a1051da commit 76d5507

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

native_sparse_attention_pytorch/triton_native_sparse_attention.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -935,8 +935,17 @@ def backward_kernel_one_col_block_sparse(
935935
offs_m * stride_kvbl_m
936936
)
937937

938-
block_indices = tl.load(kv_block_indices_ptrs + OFF_SEL_KV_BLOCKS)
939-
block_masks = tl.load(kv_block_mask_ptrs + OFF_SEL_KV_BLOCKS)
938+
block_indices = tl.load(
939+
kv_block_indices_ptrs + OFF_SEL_KV_BLOCKS,
940+
mask = offs_m < seqlen_q,
941+
other = 0.
942+
)
943+
944+
block_masks = tl.load(
945+
kv_block_mask_ptrs + OFF_SEL_KV_BLOCKS,
946+
mask = offs_m < seqlen_q,
947+
other = 0.
948+
)
940949

941950
blocks_offs_n = (
942951
block_indices[:, None] * BLOCK +

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "native-sparse-attention-pytorch"
3-
version = "0.1.8"
3+
version = "0.1.9"
44
description = "Native Sparse Attention"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)