Commit 155839b

superbobry authored and Google-ML-Automation committed
[pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a batched matmul. We need more changes in the lowering to support 3D, so I will leave it out of scope here.

Fixes jax-ml#26013.

PiperOrigin-RevId: 733293299
1 parent 8906f28 commit 155839b
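
To make the failure mode concrete (this sketch is not part of the commit), the standalone kernel below is the kind of program that now fails with the clearer ValueError instead of an opaque lowering error. It mirrors the new test added further down and assumes a GPU backend with the default, non-interpret Triton lowering.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# One operand is 1-D, so the Triton lowering now rejects the dot up front.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((32,), jnp.float32),
)
def dot_kernel(x_ref, y_ref, o_ref):
  o_ref[()] = jnp.dot(x_ref[()], y_ref[()])

x = jnp.ones((32, 64), jnp.float32)
y = jnp.ones((64,), jnp.float32)
dot_kernel(x, y)  # ValueError: a and b must be 2D, but got: ...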

File tree

2 files changed: +22 -0 lines changed

  jax/_src/pallas/triton/lowering.py
  tests/pallas/pallas_test.py

jax/_src/pallas/triton/lowering.py

Lines changed: 3 additions & 0 deletions
@@ -2261,6 +2261,9 @@ def _dot_general_lowering(
   a_type = ir.RankedTensorType(a.type)
   b_type = ir.RankedTensorType(b.type)
+  if len(a_type.shape) != len(b_type.shape) != 2:
+    raise ValueError("a and b must be 2D, but got:"
+                     f" {a_type.shape} and {b_type.shape}")
   if min(*b_type.shape) < 16:
     raise ValueError("all dimensions of b must be >= 16 ")
   if a_type.element_type != b_type.element_type:
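
A note on how the new guard reads (not part of the diff): Python chains the comparison, so len(a_type.shape) != len(b_type.shape) != 2 means "the ranks differ and b is not rank-2". Below is a tiny self-contained check of that semantics, with rejects as a hypothetical stand-in for the guard.

# `x != y != 2` is evaluated as `x != y and y != 2`.
def rejects(a_rank: int, b_rank: int) -> bool:
  return a_rank != b_rank != 2

assert rejects(2, 1)      # the case exercised by the new test: 2-D a, 1-D b
assert not rejects(2, 2)  # an ordinary 2-D matmul passes through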

tests/pallas/pallas_test.py

Lines changed: 19 additions & 0 deletions
@@ -733,6 +733,25 @@ def dot_kernel(x_ref, y_ref, o_ref):
     )
     self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3)

+  def test_dot_with_vector(self):
+    if not jtu.test_device_matches(["gpu"]) or self.INTERPRET:
+      self.skipTest(
+          "jnp.dot is only restricted to 2D on GPU in non-interpret mode."
+      )
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((32,), jnp.float32),
+    )
+    def dot_kernel(x_ref, y_ref, o_ref):
+      o_ref[()] = jnp.dot(x_ref[()], y_ref[()])
+
+    key0, key1 = random.split(random.key(0))
+    x = random.normal(key0, (32, 64), dtype=jnp.float32)
+    y = random.normal(key1, (64,), dtype=jnp.float32)
+    with self.assertRaisesRegex(Exception, "must be 2D"):
+      dot_kernel(x, y)
+
   @parameterized.parameters(jnp.int4, jnp.uint4)
   def test_subbyte_load(self, dtype):
     if not jtu.test_device_matches(["gpu"]):
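
For completeness (not from the commit): if a kernel really needs a matrix-vector product, one way to stay within the 2-D restriction is to spell it as a broadcasted multiply followed by a row reduction. This is only a sketch under the assumption that the elementwise and axis-reduction lowerings accept these shapes; the kernel name matvec_kernel is made up for illustration.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((32,), jnp.float32),
)
def matvec_kernel(x_ref, y_ref, o_ref):
  x = x_ref[()]                                # (32, 64)
  y = y_ref[()]                                # (64,)
  # Broadcasted multiply + row-wise sum instead of jnp.dot with a 1-D operand.
  o_ref[()] = jnp.sum(x * y[None, :], axis=1)  # (32,)

x = jnp.ones((32, 64), jnp.float32)
y = jnp.ones((64,), jnp.float32)
out = matvec_kernel(x, y)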
