
Commit f9f97d4

test: Moved triton quick return to skip and reduced test sizes
Signed-off-by: Brandon Groth <[email protected]>
1 parent 67af0b1 commit f9f97d4
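
The pattern behind this change is generic pytest usage: replace an early return inside the test body with a @pytest.mark.skipif decorator, so CPU-only runs report the test as skipped (with a reason) instead of silently passing. A minimal before/after sketch, using an illustrative test name rather than the ones touched in this diff:

import pytest
import torch

# Before: the body bails out on machines without a GPU, so the test is
# reported as "passed" even though nothing was exercised.
def test_cuda_kernel_quick_return():
    if not torch.cuda.is_available():
        # only run the test when GPU is available
        return

# After: the same condition expressed as a skip marker; pytest reports the
# case as skipped and records the reason string.
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="test_cuda_kernel can only run when GPU is available",
)
def test_cuda_kernel():
    assert torch.cuda.is_available()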

File tree

2 files changed, +16 -13 lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 # Third Party
 import torch
 
-# First Party
+# Local
 from fms_mo.utils.import_utils import available_packages
 
 # Assume any calls to the file are requesting triton
@@ -27,11 +27,11 @@
 )
 
 # Third Party
+# pylint: disable=wrong-import-position
 from triton.language.extra import libdevice
 import triton
 import triton.language as tl
 
-
 DTYPE_I8 = [torch.int8]
 DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2]
 DTYPE_8BIT = DTYPE_I8 + DTYPE_F8
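
For context on the pylint directive above: the triton imports in this module sit below the available_packages check, which pylint flags as wrong-import-position. A minimal sketch of that layout, assuming a hypothetical guard body (only the import lines and comments come from the diff; the lookup on available_packages and the error message are guesses):

# Third Party
import torch

# Local
from fms_mo.utils.import_utils import available_packages

# Assume any calls to the file are requesting triton
if not available_packages["triton"]:  # hypothetical mapping-style lookup
    raise ImportError("triton is required to use these kernels")  # hypothetical message

# Third Party
# pylint: disable=wrong-import-position
# Imports placed after module-level statements trigger wrong-import-position,
# so the disable keeps the availability guard ahead of the triton imports.
from triton.language.extra import libdevice
import triton
import triton.language as tl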

tests/triton_kernels/test_triton_mm.py

Lines changed: 14 additions & 11 deletions
@@ -32,7 +32,7 @@
 )
 
 
-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
 @pytest.mark.parametrize(
     "dtype_to_test",
     [
@@ -43,11 +43,12 @@
         torch.float8_e5m2,
     ],
 )
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_fp can only when GPU is available",
+)
 def test_triton_matmul_fp(mkn, dtype_to_test):
     """Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     m = n = k = mkn
@@ -79,12 +80,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
     assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3
 
 
-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_int8 can only when GPU is available",
+)
 def test_triton_matmul_int8(mkn):
     """Parametric tests for triton imatmul kernel using variety of tensor sizes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     m = n = k = mkn
@@ -121,13 +123,14 @@ def test_triton_matmul_int8(mkn):
 
 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
 @pytest.mark.parametrize("trun_bits", [0, 8, 12, 16])
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_linear_fpx_acc can only when GPU is available",
+)
 def test_linear_fpx_acc(feat_in_out, trun_bits):
     """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run
     on CUDA.
     """
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     feat_in, feat_out = feat_in_out
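
With the guards expressed as skip markers, a run on a CPU-only machine reports these tests as skipped rather than passed; invoking pytest with its -rs reporting option (for example, pytest -rs tests/triton_kernels/test_triton_mm.py) lists each skipped test together with the reason string from its decorator.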
