Skip to content

Commit 19277de

Browse files
authored
[NVIDIA] Update min_dot_sizes (#7411)
The original sizes (16 on N) are unnecessarily high
1 parent fcf3e3e commit 19277de

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ def test_min_dot_size(dtype):
396396
error_msg = "Input shapes should have "
397397
if is_cuda():
398398
if dtype.primitive_bitwidth == 8:
399-
error_msg += "M >= 16, N >= 16 and K >= 32"
399+
error_msg += "M >= 16, N >= 8 and K >= 32"
400400
else:
401-
error_msg = "M >= 16, N >= 16 and K >= 16"
401+
error_msg = "M >= 16, N >= 8 and K >= 16"
402402
elif is_hip():
403403
# hip supports arbitrary sizes
404404
error_msg = None

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m,
2323
rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
2424
assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
2525
if lhs_bitwidth == 8:
26-
return (16, 16, 32)
26+
return (16, 8, 32)
2727
else:
28-
return (16, 16, 16)
28+
return (16, 8, 16)
2929

3030
return check_dot_compatibility
3131

0 commit comments

Comments
 (0)