Commit 7ace72f

[Pallas] Be explicit about accumulation dtype in reference implementations
1 parent 6004a50 commit 7ace72f

3 files changed: +16, -7 lines changed
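
For readers unfamiliar with `preferred_element_type`, the sketch below (not part of this commit; shapes and values are purely illustrative) shows what making the accumulation dtype explicit does for `jnp.dot` on low-precision inputs:

import jax.numpy as jnp

# Illustrative low-precision operands (shapes are arbitrary).
x = jnp.ones((8, 16), dtype=jnp.bfloat16)
y = jnp.ones((16, 4), dtype=jnp.bfloat16)

# Default: the output dtype is derived from the inputs (bfloat16 here),
# and the accumulation dtype is left to the backend.
out_default = jnp.dot(x, y)

# Explicit: request float32 accumulation and a float32 result, then cast
# back if a low-precision value is needed for comparison.
out_f32 = jnp.dot(x, y, preferred_element_type=jnp.float32)
out_cast = out_f32.astype(x.dtype)

print(out_default.dtype, out_f32.dtype, out_cast.dtype)  # bfloat16 float32 bfloat16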

jax/experimental/pallas/ops/gpu/paged_attention.py

Lines changed: 4 additions & 2 deletions

@@ -394,7 +394,8 @@ def paged_attention_reference(
     ) # [batch_size, num_kv_heads, kv_seq_len, head_dim]

     uncapped_logits = jnp.einsum(
-        "bkgd,bksd->bkgs", q_reshaped, k_transposed
+        "bkgd,bksd->bkgs", q_reshaped, k_transposed,
+        preferred_element_type=jnp.float32
     ).astype(jnp.float32)

     if attn_logits_soft_cap is not None:
@@ -410,7 +411,8 @@ def paged_attention_reference(

     weights = jax.nn.softmax(logits, axis=-1)
     o = jnp.einsum(
-        "bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32)
+        "bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32),
+        preferred_element_type=jnp.float32
     ).astype(q.dtype)
     o = o.reshape(q.shape)

tests/pallas/ops_test.py

Lines changed: 4 additions & 1 deletion

@@ -1652,7 +1652,10 @@ def dot(x_ref, y_ref, o_ref):
     x = random.normal(k1, lhs_shape, dtype=dtype)
     y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
-    expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
+    # Pallas always accumulates in FP32, so we are explicit about
+    # preferred_element_type here.
+    expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y,
+                       preferred_element_type=jnp.float32).astype(dtype)
     np.testing.assert_allclose(
         out.astype(jnp.float32),
         expected.astype(jnp.float32),
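
As a standalone illustration of the comparison pattern in this test, here is a hedged sketch; `pallas_dot_output` is a hypothetical stand-in for the Pallas kernel result, which, per the comment above, accumulates in FP32 and returns the input dtype:

import numpy as np
import jax.numpy as jnp
from jax import random

dtype = jnp.bfloat16
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (32, 64), dtype=dtype)
y = random.normal(k2, (64, 16), dtype=dtype)

# Hypothetical stand-in for the Pallas kernel output: float32 accumulation,
# result cast back to the input dtype.
pallas_dot_output = jnp.dot(x, y, preferred_element_type=jnp.float32).astype(dtype)

# Reference computed with the same explicit accumulation dtype and cast back
# to `dtype`, so both sides see the same final rounding before the check.
expected = jnp.dot(x, y, preferred_element_type=jnp.float32).astype(dtype)

np.testing.assert_allclose(pallas_dot_output.astype(jnp.float32),
                           expected.astype(jnp.float32),
                           atol=0.05, rtol=0.05)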

tests/pallas/pallas_test.py

Lines changed: 8 additions & 4 deletions

@@ -553,8 +553,10 @@ def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
-    out, expected = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
-                           interpret=self.INTERPRET), jnp.matmul(x, y)
+    out = matmul(x, y, bm=bm, bn=bn, bk=bk, gm=gm,
+                 interpret=self.INTERPRET)
+    expected = jnp.matmul(
+        x, y, preferred_element_type=jnp.float32).astype(dtype)
     np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)

   @parameterized.named_parameters(*[
@@ -576,8 +578,10 @@ def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
-    out, expected = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
-                                      interpret=self.INTERPRET), jnp.matmul(x, y)
+    out = matmul_block_spec(x, y, bm=bm, bn=bn, bk=bk,
+                            interpret=self.INTERPRET)
+    expected = jnp.matmul(
+        x, y, preferred_element_type=jnp.float32).astype(dtype)
     np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)

   @parameterized.named_parameters(*(
