
Commit e418e88

WindQAQ authored and Google-ML-Automation committed
[Pallas] Add non-square pl.dot test cases.
PiperOrigin-RevId: 704788500
1 parent 593143e · commit e418e88

File tree

1 file changed (+51, -6 lines)


tests/pallas/ops_test.py

Lines changed: 51 additions & 6 deletions
@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest

@@ -62,6 +63,10 @@
 floatx = dtypes.canonicalize_dtype(jnp.float64)


+def is_power_of_two(n: int) -> bool:
+  return (n > 0) and (n & (n - 1) == 0)
+
+
 def smem_on_tpu():
   if jtu.test_device_matches(["tpu"]):
     return pltpu.SMEM

@@ -1410,12 +1415,45 @@ def f(x_ref, o_ref):
     np.testing.assert_allclose(f(x), expected)

   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((24, 128), (128, 256)),
+          ((24, 128), (256, 128)),
+          ((24, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+          ((128, 24), (128, 256)),
+          ((128, 24), (256, 128)),
+          ((256, 24), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")

@@ -1427,12 +1465,19 @@ def test_dot(self, size, dtype, trans_x, trans_y):
     if jtu.test_device_matches(["gpu"]):
       if dtype == jnp.bfloat16:
         self.skipTest("bfloat16 type are not supported on GPU")
-      if size > 128:
+      if (
+          math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
+          > (256 * 256) * 2
+      ):
         self.skipTest("Shared memory size limit exceeded")
+      if min(*lhs_shape, *rhs_shape) < 16:
+        self.skipTest("All dimensions of lhs and rhs must be >= 16")
+      if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape):
+        self.skipTest("All dimensions of lhs and rhs must be power of two")

     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((size, size), dtype),
+        out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
         grid=1,
     )
     def dot(x_ref, y_ref, o_ref):

@@ -1441,8 +1486,8 @@ def dot(x_ref, y_ref, o_ref):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)

     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
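
For readers unfamiliar with the kernel under test, the sketch below (not part of the commit; the shapes, dtype, and function names are illustrative assumptions) shows the pattern the new non-square cases exercise: pl.dot contracting an (8, 128) block with a (128, 256) block inside a single pallas_call. Device-specific limits, such as the GPU skips added in this commit, still apply.

# Minimal sketch of a non-square pl.dot matmul kernel. Shapes, dtype, and the
# kernel name are illustrative assumptions, not part of the commit.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, y_ref, o_ref):
  # Read the whole (8, 128) and (128, 256) blocks and contract them with pl.dot.
  o_ref[:, :] = pl.dot(x_ref[:, :], y_ref[:, :]).astype(o_ref.dtype)


x = jnp.ones((8, 128), jnp.float32)
y = jnp.ones((128, 256), jnp.float32)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 256), jnp.float32),  # (M, N) output block
    grid=1,
)(x, y)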
