32 | 32 | )
33 | 33 |
34 | 34 |
35 | | -@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
| 35 | +@pytest.mark.parametrize("mkn", [64, 256, 1024])
36 | 36 | @pytest.mark.parametrize(
37 | 37 |     "dtype_to_test",
38 | 38 |     [
43 | 43 |         torch.float8_e5m2,
44 | 44 |     ],
45 | 45 | )
| 46 | +@pytest.mark.skipif(
| 47 | +    not torch.cuda.is_available(),
| 48 | +    reason="test_triton_matmul_fp can only run when GPU is available",
| 49 | +)
46 | 50 | def test_triton_matmul_fp(mkn, dtype_to_test):
47 | 51 |     """Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes."""
48 | | -    if not torch.cuda.is_available():
49 | | -        # only run the test when GPU is available
50 | | -        return
51 | 52 |
52 | 53 |     torch.manual_seed(23)
53 | 54 |     m = n = k = mkn
@@ -79,12 +80,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
79 | 80 |     assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3
80 | 81 |
81 | 82 |
82 | | -@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
| 83 | +@pytest.mark.parametrize("mkn", [64, 256, 1024])
| 84 | +@pytest.mark.skipif(
| 85 | +    not torch.cuda.is_available(),
| 86 | +    reason="test_triton_matmul_int8 can only run when GPU is available",
| 87 | +)
83 | 88 | def test_triton_matmul_int8(mkn):
84 | 89 |     """Parametric tests for triton imatmul kernel using variety of tensor sizes."""
85 | | -    if not torch.cuda.is_available():
86 | | -        # only run the test when GPU is available
87 | | -        return
88 | 90 |
89 | 91 |     torch.manual_seed(23)
90 | 92 |     m = n = k = mkn
@@ -121,13 +123,14 @@ def test_triton_matmul_int8(mkn):
121 | 123 |
122 | 124 | @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
123 | 125 | @pytest.mark.parametrize("trun_bits", [0, 8, 12, 16])
| 126 | +@pytest.mark.skipif(
| 127 | +    not torch.cuda.is_available(),
| 128 | +    reason="test_linear_fpx_acc can only run when GPU is available",
| 129 | +)
124 | 130 | def test_linear_fpx_acc(feat_in_out, trun_bits):
125 | 131 |     """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run
126 | 132 |     on CUDA.
127 | 133 |     """
128 | | -    if not torch.cuda.is_available():
129 | | -        # only run the test when GPU is available
130 | | -        return
131 | 134 |
132 | 135 |     torch.manual_seed(23)
133 | 136 |     feat_in, feat_out = feat_in_out
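
For reference, the diff above swaps the in-test early `return` guards for `@pytest.mark.skipif` decorators, so GPU-only tests are reported as skipped on CPU machines instead of silently passing. A minimal sketch of the same pattern, assuming only `pytest` and `torch`; the `requires_cuda` helper, the test name, and the `torch.matmul` stand-in for the triton kernel are illustrative, not part of this PR:

```python
import pytest
import torch

# Reusing one marker keeps the skip condition and reason consistent across tests.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="this test can only run when a GPU is available",
)


@requires_cuda
@pytest.mark.parametrize("mkn", [64, 256, 1024])
def test_matmul_shapes(mkn):
    """On CPU-only runs pytest reports this as skipped rather than passed."""
    a = torch.randn(mkn, mkn, device="cuda")
    b = torch.randn(mkn, mkn, device="cuda")
    out = torch.matmul(a, b)
    assert out.shape == (mkn, mkn)
```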