Skip to content

Commit 70dbcdd

Browse files
committed
Fix varlen seqused_k behaviour with 0-length sequences.
1 parent 1ca96c7 commit 70dbcdd

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/flash_attn_jax/varlen.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,8 @@ def _flash_mha_varlen_vjp_bwd(config, pack, dout):
34 34   ends = seqlens_k[1:]
35 35   lens = ends - starts
36 36   zero = jnp.zeros(q.shape[0], dtype=jnp.int32)
37    - ixl = jnp.arange(q.shape[0]) - jnp.cumsum(zero.at[ends].set(lens))
38    - limits = jnp.cumsum(zero.at[starts].set(seqused_k-jnp.concatenate([jnp.array([0]), seqused_k[:-1]])))
   37 + ixl = jnp.arange(q.shape[0]) - jnp.cumsum(zero.at[ends].add(lens))
   38 + limits = jnp.cumsum(zero.at[starts].add(seqused_k-jnp.concatenate([jnp.array([0]), seqused_k[:-1]])))
39 39   mask = (ixl < limits)
40 40   v = v * mask[:, None, None]
41 41   k = k * mask[:, None, None]

tests/test_varlen.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def check1(ref_out, jax_out, out):
42 42   @pytest.mark.parametrize("m", [1, 2]) # for MQA/GQA
43 43   def test_varlen_flash_fwd(m, h, d, causal, local, dtype, seqused_k_limit):
44 44   window_size = (3,3) if local else (-1,-1)
45    - lens = [1, 2, 6, 10]
   45 + lens = [1, 2, 0, 6, 10]
46 46   b = len(lens)
47 47   total_seqlen = sum(lens)
48 48

@@ -91,7 +91,7 @@ def ref(q,k,v):
91 91   @pytest.mark.parametrize("m", [1, 2]) # for MQA/GQA
92 92   def test_varlen_flash_bwd(m, h, d, causal, local, dtype, seqused_k_limit):
93 93   window_size = (3,3) if local else (-1,-1)
94    - lens = [1, 2, 4, 6, 8, 10]
   94 + lens = [1, 2, 0, 6, 10]
95 95   b = len(lens)
96 96   total_seqlen = sum(lens)
97 97   if seqused_k_limit is not None and (causal or local):

0 commit comments

Comments
 (0)