@@ -17,6 +17,7 @@
 from collections.abc import Sequence
 import functools
 import itertools
+import math
 import sys
 from typing import Any
 import unittest
@@ -62,6 +63,12 @@
 floatx = dtypes.canonicalize_dtype(jnp.float64)
 
 
+def is_power_of_two(n: int) -> bool:
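+  # n & (n - 1) clears the lowest set bit, so the AND is zero exactly when n
+  # is a power of two.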
+  return (n > 0) and (n & (n - 1) == 0)
+
+
 def smem_on_tpu():
   if jtu.test_device_matches(["tpu"]):
     return pltpu.SMEM
@@ -1410,12 +1415,47 @@ def f(x_ref, o_ref):
     np.testing.assert_allclose(f(x), expected)
 
   @parameterized.product(
-      size=[16, 32, 64, 128, 256],
+      lhs_and_rhs_shape=[
+          ((16, 16), (16, 16)),
+          ((32, 32), (32, 32)),
+          ((64, 64), (64, 64)),
+          ((128, 128), (128, 128)),
+          ((256, 256), (256, 256)),
+          ((8, 128), (128, 256)),
+          ((8, 128), (256, 128)),
+          ((8, 256), (256, 128)),
+          ((16, 128), (128, 256)),
+          ((16, 128), (256, 128)),
+          ((16, 256), (256, 128)),
+          ((24, 128), (128, 256)),
+          ((24, 128), (256, 128)),
+          ((24, 256), (256, 128)),
+          ((128, 8), (128, 256)),
+          ((128, 8), (256, 128)),
+          ((256, 8), (256, 128)),
+          ((128, 16), (128, 256)),
+          ((128, 16), (256, 128)),
+          ((256, 16), (256, 128)),
+          ((128, 24), (128, 256)),
+          ((128, 24), (256, 128)),
+          ((256, 24), (256, 128)),
+      ],
       dtype=[jnp.float32, jnp.float16, jnp.bfloat16],
       trans_x=[False, True],
       trans_y=[False, True],
   )
-  def test_dot(self, size, dtype, trans_x, trans_y):
+  def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y):
+    lhs_shape, rhs_shape = lhs_and_rhs_shape
+
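+    # pl.dot transposes its operands before contracting, so apply the same
+    # transpositions here to recover the effective (M, K) @ (K, N) shapes.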
+    final_lhs_shape = lhs_shape[::-1] if trans_x else lhs_shape
+    final_rhs_shape = rhs_shape[::-1] if trans_y else rhs_shape
+    if final_lhs_shape[1] != final_rhs_shape[0]:
+      self.skipTest("Contraction dimensions do not match")
+
+    out_shape = (final_lhs_shape[0], final_rhs_shape[1])
+
     if jtu.test_device_matches(["tpu"]):
       if dtype == jnp.float16:
         self.skipTest("float16 type is not supported on TPU")
@@ -1427,12 +1465,23 @@ def test_dot(self, size, dtype, trans_x, trans_y):
     if jtu.test_device_matches(["gpu"]):
       if dtype == jnp.bfloat16:
         self.skipTest("bfloat16 type is not supported on GPU")
-      if size > 128:
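+      # All three buffers (lhs, rhs, out) must fit in shared memory, so skip
+      # cases whose combined element count exceeds a 2 * 256 * 256 budget.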
+      if (
+          math.prod(lhs_shape) + math.prod(rhs_shape) + math.prod(out_shape)
+          > (256 * 256) * 2
+      ):
         self.skipTest("Shared memory size limit exceeded")
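+      # The GPU lowering (Triton) additionally restricts dot operands to
+      # power-of-two dimensions no smaller than 16.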
+      if min(*lhs_shape, *rhs_shape) < 16:
+        self.skipTest("All dimensions of lhs and rhs must be >= 16")
+      if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape):
+        self.skipTest("All dimensions of lhs and rhs must be a power of two")
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((size, size), dtype),
+        out_shape=jax.ShapeDtypeStruct(out_shape, dtype),
         grid=1,
     )
     def dot(x_ref, y_ref, o_ref):
@@ -1441,8 +1486,8 @@ def dot(x_ref, y_ref, o_ref):
       o_ref[:, :] = pl.dot(x, y, trans_x, trans_y).astype(o_ref.dtype)
 
     k1, k2 = random.split(random.key(0))
-    x = random.normal(k1, (size, size), dtype=dtype)
-    y = random.normal(k2, (size, size), dtype=dtype)
+    x = random.normal(k1, lhs_shape, dtype=dtype)
+    y = random.normal(k2, rhs_shape, dtype=dtype)
     out = dot(x, y)
     expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y)
     np.testing.assert_allclose(
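
A note on the shape bookkeeping in test_dot, as a minimal standalone sketch.
The helper name effective_out_shape is illustrative only (not part of the
change); the sketch assumes pl.dot(x, y, trans_x, trans_y) contracts
(x.T if trans_x else x) @ (y.T if trans_y else y), matching the `expected`
reference value the test computes with jnp.dot.

def effective_out_shape(lhs_shape, rhs_shape, trans_x, trans_y):
  # Mirror the transpositions pl.dot applies, then check the contraction dims.
  m, k = lhs_shape[::-1] if trans_x else lhs_shape
  k2, n = rhs_shape[::-1] if trans_y else rhs_shape
  if k != k2:
    return None  # test_dot calls skipTest for these combinations
  return (m, n)

assert effective_out_shape((8, 128), (128, 256), False, False) == (8, 256)
assert effective_out_shape((128, 8), (128, 256), True, False) == (8, 256)
assert effective_out_shape((8, 128), (256, 128), False, False) is None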