Skip to content

Commit c5aa86a

Browse files
Remove redundant filtering in the paged flash attention kernel
Reason: `l_next >= 1.0`, so the `jnp.where(l_next == 0.0, 1.0, l_next)` guard against division by zero is not needed.

PiperOrigin-RevId: 741400472
1 parent a52f7b2 commit c5aa86a

File tree

1 file changed: 2 additions and 3 deletions (+2, -3 lines changed).

jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -274,14 +274,13 @@ def prefetch_next_block(): # pylint: disable=unused-variable
274  274    alpha = jnp.exp(m_prev - m_next)
275  275    beta = jnp.exp(m_curr - m_next)
276  276    l_next = alpha * l_prev + beta * l_curr
277       - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
     277  + m_ref[...], l_ref[...] = m_next, l_next
278  278
279  279    v = async_copy_v.wait_and_get_loaded()
280  280    o_curr_times_l_curr = jnp.dot(s_curr, v)
281  281
282       - m_ref[...], l_ref[...] = m_next, l_next_safe
283  282    o_ref[...] = (
284       - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe
     283  + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next
285  284    ).astype(o_ref.dtype)
286  285
287  286    step_ref[0] = step + 1

0 commit comments

Comments
 (0)