Commit c5aa86a
Remove redundant filtering in the paged flash attention kernel
Reason: `l_next >= 1.0` so the `jnp.where(l_next == 0.0, 1.0, l_next)` clause is not needed.
PiperOrigin-RevId: 7414004721
Parent: a52f7b2
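The claim in the commit message can be checked with a small sketch of the online-softmax normalizer update. This is not the kernel code from this commit: the helper `online_softmax_step`, the sample score blocks, and the update formula are assumptions based on the standard flash attention recurrence, and the argument assumes all scores (including masked ones) are finite. Under those assumptions, each block's local sum contains the term `exp(0) = 1` for its maximum element, so `l_next` never drops below 1 and the `jnp.where` guard is a no-op.

```python
import jax.numpy as jnp


def online_softmax_step(m_prev, l_prev, scores):
    """One block update of the running max and normalizer (hypothetical helper)."""
    m_curr = jnp.max(scores, axis=-1)
    # The maximum element contributes exp(0) == 1, so l_curr >= 1 for finite scores.
    l_curr = jnp.sum(jnp.exp(scores - m_curr[..., None]), axis=-1)
    m_next = jnp.maximum(m_prev, m_curr)
    alpha = jnp.exp(m_prev - m_next)  # rescales the previous normalizer
    beta = jnp.exp(m_curr - m_next)   # rescales the current block's normalizer
    # Either alpha == 1 or beta == 1, so l_next >= 1 once a block has been seen.
    l_next = alpha * l_prev + beta * l_curr
    return m_next, l_next


# Illustrative score blocks for two query rows (made-up values).
blocks = [
    jnp.array([[0.5, -1.0, 2.0], [-30.0, -40.0, -35.0]]),
    jnp.array([[-3.0, 0.1, 0.0], [-50.0, -20.0, -60.0]]),
]

m = jnp.full((2,), -jnp.inf)  # running max before any block
l = jnp.zeros((2,))           # running normalizer before any block
history = []
for scores in blocks:
    m, l = online_softmax_step(m, l, scores)
    history.append(l)

# The removed guard would leave every l_next unchanged.
guard_is_noop = all(
    bool(jnp.array_equal(jnp.where(l_i == 0.0, 1.0, l_i), l_i)) for l_i in history
)
```

The guard would only matter if a whole row had a running maximum of `-inf`, which this sketch rules out by assuming finite scores.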
1 file changed in jax/experimental/pallas/ops/tpu/paged_attention: 2 additions & 3 deletions
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| 274 | 274 | |
| 275 | 275 | |
| 276 | 276 | |
| 277 | | - |
| | 277 | + |
| 278 | 278 | |
| 279 | 279 | |
| 280 | 280 | |
| 281 | 281 | |
| 282 | | - |
| 283 | 282 | |
| 284 | | - |
| | 283 | + |
| 285 | 284 | |
| 286 | 285 | |
| 287 | 286 | |