We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c502332 commit 2c165bfCopy full SHA for 2c165bf
jax/_src/pallas/triton/lowering.py
@@ -2175,8 +2175,8 @@ def _dot_general_lowering(
2175
2176
a_type = ir.RankedTensorType(a.type)
2177
b_type = ir.RankedTensorType(b.type)
2178
- if min(*a_type.shape, *b_type.shape) < 16:
2179
- raise ValueError("all dimensions of a and b must be >= 16 ")
+ if min(*b_type.shape) < 16:
+ raise ValueError("all dimensions of b must be >= 16 ")
2180
if a_type.element_type != b_type.element_type:
2181
raise ValueError(
2182
"a and b must have the same element type, but got:"
0 commit comments