Skip to content

Commit cf21f73

Browse files
Merge pull request jax-ml#27258 from jakevdp:fix-lint
PiperOrigin-RevId: 738492394
2 parents 84ec21e + 7a67c9b commit cf21f73

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def init_scratch_ref():
431431
causal_mask = row_ids < col_ids
432432
if sliding_window is not None:
433433
causal_mask = jnp.logical_or(causal_mask,
434-
row_ids - sliding_window>=col_ids)
434+
row_ids - sliding_window >= col_ids)
435435
qk += jnp.where(causal_mask, mask_value, 0.0)
436436
m_curr = jnp.max(qk, axis=1, keepdims=True)
437437
s_curr = jnp.exp(qk - m_curr)

tests/pallas/tpu_ragged_paged_attention_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_ragged_paged_attention_complex(
306306
)
307307

308308
def test_ragged_paged_attention_sliding_window_should_be_positive(self):
309-
dtype=jnp.float32
309+
dtype = jnp.float32
310310
seq_lens = [(192, 328), (128, 180), (64, 255)]
311311
num_heads = (32, 8)
312312
head_dim = 128

0 commit comments

Comments
 (0)