Skip to content

Commit 2c165bf

Browse files
[pallas:triton] Lift dot_general restriction on minimal tile size for a.
PiperOrigin-RevId: 725605869
1 parent c502332 commit 2c165bf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,8 +2175,8 @@ def _dot_general_lowering(
21752175

21762176
a_type = ir.RankedTensorType(a.type)
21772177
b_type = ir.RankedTensorType(b.type)
2178-
if min(*a_type.shape, *b_type.shape) < 16:
2179-
raise ValueError("all dimensions of a and b must be >= 16 ")
2178+
if min(*b_type.shape) < 16:
2179+
raise ValueError("all dimensions of b must be >= 16 ")
21802180
if a_type.element_type != b_type.element_type:
21812181
raise ValueError(
21822182
"a and b must have the same element type, but got:"

0 commit comments

Comments
 (0)