
Commit 4f6b11c

[FRONTEND] Relax minimum dot size on Nvidia target (#7451)
We now support using tensor cores with padding, so we can relax the minimum dot size.
1 parent 17a2be8 · commit 4f6b11c
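
For context, the kind of kernel this relaxation enables is a dot whose block shape is smaller than the old 16x8 minimum on CUDA. The sketch below is illustrative only and is not part of the commit: the kernel name, pointer arithmetic, and shapes are assumptions, chosen to mirror the new (M=2, N=4, K=32) fp16 test case, and it needs a CUDA GPU plus a Triton build that includes this change.

import torch
import triton
import triton.language as tl


@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major loads; no masks needed because the whole problem is one block.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # The frontend used to reject M < 16 or N < 8 here on CUDA; with this commit
    # the minimum is M >= 1, N >= 1 (K >= 16 for fp16, K >= 32 for 8-bit inputs).
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


a = torch.randn((2, 32), device='cuda', dtype=torch.float16)
b = torch.randn((32, 4), device='cuda', dtype=torch.float16)
c = torch.empty((2, 4), device='cuda', dtype=torch.float32)
small_dot_kernel[(1, )](a, b, c, M=2, N=4, K=32)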

3 files changed: +19 -8 lines changed


python/test/unit/language/test_compile_errors.py

Lines changed: 2 additions & 2 deletions

@@ -396,9 +396,9 @@ def test_min_dot_size(dtype):
     error_msg = "Input shapes should have "
     if is_cuda():
         if dtype.primitive_bitwidth == 8:
-            error_msg += "M >= 16, N >= 8 and K >= 32"
+            error_msg += "M >= 1, N >= 1 and K >= 32"
         else:
-            error_msg = "M >= 16, N >= 8 and K >= 16"
+            error_msg = "M >= 1, N >= 1 and K >= 16"
     elif is_hip():
         # hip supports arbitrary sizes
         error_msg = None

python/test/unit/language/test_core.py

Lines changed: 14 additions & 4 deletions

@@ -3838,6 +3838,13 @@ def get_test_dot_vdot2_cases():
             (4, 32, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)]


+def get_test_small_dots_cases():
+    if not is_cuda():
+        return []
+    return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
+            (1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)]
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size",
@@ -3851,15 +3858,16 @@ def get_test_dot_vdot2_cases():
     get_test_dot_fp8_output_cases() + \
     get_test_dot_small_k_mfma_cases() + \
     get_test_dot_small_mn_fma_cases() + \
-    get_test_dot_softmax())
+    get_test_dot_softmax() + \
+    get_test_small_dots_cases())
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
 def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size,
              num_ctas, device):
     if is_interpreter():
         if in_dtype == 'bfloat16':
             pytest.skip("bfloat16 is not supported in the interpreter")
     else:
-        if not is_hip() and (M < 16 or N < 16 or K < 16):
+        if not is_hip() and K < 16:
             pytest.skip("small dots are supported only on HIP at the moment")
     if is_cuda():
         capability = torch.cuda.get_device_capability()
@@ -4097,10 +4105,12 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
             assert 'wgmma.mma_async.sync.aligned' in ptx or\
                 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
         elif in_dtype == "float8e5" and out_dtype == tl.float32:
-            if capability[0] == 9:
+            if capability[0] == 9 and M >= 64 and N >= 8:
                 assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx
+            elif capability[0] >= 8 and M < 64:
+                assert 'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32' in ptx
         elif in_dtype == "float8e4nv" and out_dtype == tl.float32:
-            if capability[0] == 9:
+            if capability[0] == 9 and M >= 64 and N >= 8:
                 assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
         if is_tcgen5 and epilogue == 'softmax' and M >= 128:
             # check that there is no shared memory exchange in the softmax
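
The tightened assertions above check which MMA instruction a small-M dot lowers to: a plain mma.sync tile rather than a wgmma instruction when M < 64. As an illustrative follow-on to the sketch near the top of this page (again an assumption, not part of the commit), the same launch handle the tests use exposes the generated PTX:

# Re-launch the small_dot_kernel sketch from above and read its PTX the way
# test_dot does; expect an mma.sync.* instruction rather than wgmma when M is small.
pgm = small_dot_kernel[(1, )](a, b, c, M=2, N=4, K=32)
print('mma.sync' in pgm.asm['ptx'])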

third_party/nvidia/backend/compiler.py

Lines changed: 3 additions & 2 deletions

@@ -22,10 +22,11 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m,
         lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
         rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
         assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        # For small M/N inputs we can still use tensorcores with padding.
         if lhs_bitwidth == 8:
-            return (16, 8, 32)
+            return (1, 1, 32)
         else:
-            return (16, 8, 16)
+            return (1, 1, 16)

     return check_dot_compatibility

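
Restated as standalone Python for readability (the helper name below is made up; the backend source is the check_dot_compatibility closure shown in the diff), the relaxed minimum-shape rule is:

def nvidia_min_dot_size(bitwidth: int) -> tuple:
    # M and N may now be as small as 1 because narrow operands are padded up to
    # a tensor-core tile; K must still cover one MMA instruction's depth.
    if bitwidth == 8:        # fp8 / int8 operands
        return (1, 1, 32)
    return (1, 1, 16)        # fp16, bf16 and other non-8-bit operands


assert nvidia_min_dot_size(8) == (1, 1, 32)
assert nvidia_min_dot_size(16) == (1, 1, 16)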
