We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1ca96c7 · commit 70dbcdd (Copy full SHA for 70dbcdd)
src/flash_attn_jax/varlen.py
@@ -34,8 +34,8 @@ def _flash_mha_varlen_vjp_bwd(config, pack, dout):
34
ends = seqlens_k[1:]
35
lens = ends - starts
36
zero = jnp.zeros(q.shape[0], dtype=jnp.int32)
37
- ixl = jnp.arange(q.shape[0]) - jnp.cumsum(zero.at[ends].set(lens))
38
- limits = jnp.cumsum(zero.at[starts].set(seqused_k-jnp.concatenate([jnp.array([0]), seqused_k[:-1]])))
+ ixl = jnp.arange(q.shape[0]) - jnp.cumsum(zero.at[ends].add(lens))
+ limits = jnp.cumsum(zero.at[starts].add(seqused_k-jnp.concatenate([jnp.array([0]), seqused_k[:-1]])))
39
mask = (ixl < limits)
40
v = v * mask[:, None, None]
41
k = k * mask[:, None, None]
tests/test_varlen.py
@@ -42,7 +42,7 @@ def check1(ref_out, jax_out, out):
42
@pytest.mark.parametrize("m", [1, 2]) # for MQA/GQA
43
def test_varlen_flash_fwd(m, h, d, causal, local, dtype, seqused_k_limit):
44
window_size = (3,3) if local else (-1,-1)
45
- lens = [1, 2, 6, 10]
+ lens = [1, 2, 0, 6, 10]
46
b = len(lens)
47
total_seqlen = sum(lens)
48
@@ -91,7 +91,7 @@ def ref(q,k,v):
91
92
def test_varlen_flash_bwd(m, h, d, causal, local, dtype, seqused_k_limit):
93
94
- lens = [1, 2, 4, 6, 8, 10]
95
96
97
if seqused_k_limit is not None and (causal or local):
0 commit comments