Commit 07ac3a4
Enable python/triton_kernels/tests as optional in test-triton.sh (#4924)
- Add tests from the triton_kernels directory to test-triton.sh. The tests can be triggered with test-triton.sh --triton-kernels and are ENABLED BY DEFAULT.
- Ignore tests in the test_tensor_details directory to skip NVIDIA HW tests.
- Added a get_device_capability() function to conftest.py with XPU support, returning the capability tuple (9,) for XPU devices and (0,) as a fallback for unknown devices (see the sketch below).
- Parametrized device in the mxfp and routing tests, enabling testing on different backends including XPU.
- Added XPU backend support to target_info.py by implementing an is_xpu() function for XPU device detection and extending num_sms() to return max_compute_units for XPU devices.
- Fixed SwiGLU kernel XPU compatibility by excluding the XPU backend from the NVIDIA-specific maxnreg parameter.
1 parent 42bc898 commit 07ac3a4
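The conftest.py change referenced in the commit message is not among the hunks shown below, so the following is only a minimal sketch of what the described device fixture and get_device_capability() helper could look like. The --device option name and the fixture wiring are assumptions; only the (9,) / (0,) return values come from the commit message.

```python
# Hypothetical sketch of python/triton_kernels/tests/conftest.py (not shown in this diff).
import pytest
import torch


def pytest_addoption(parser):
    # Assumed CLI option so the suite can target "cuda" or "xpu".
    parser.addoption("--device", action="store", default="cuda")


@pytest.fixture
def device(request):
    # The `device` fixture consumed by the parametrized tests below.
    return request.config.getoption("--device")


def get_device_capability(device="cuda"):
    # CUDA reports its real capability; XPU is treated as capability (9,)
    # and unknown devices fall back to (0,), per the commit message.
    if device == "cuda" and torch.cuda.is_available():
        return torch.cuda.get_device_capability()
    if device == "xpu":
        return (9, )
    return (0, )
```

Note that in the hunks below the float8 skips still call torch.cuda.get_device_capability() directly and guard it with device == "cuda", so the helper above is only one plausible shape for the conftest change.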

File tree

21 files changed (+110 / -25 lines)

.github/workflows/build-test-reusable.yml

Lines changed: 5 additions & 0 deletions
@@ -311,6 +311,11 @@ jobs:
         run: |
           ${{ env.TRITON_TEST_CMD }} --interpreter

+      - name: Run triton kernels tests
+        if: matrix.suite == 'rest'
+        run: |
+          ${{ env.TRITON_TEST_CMD }} --triton-kernels
+
       # FIXME: make sure new tutorials are added to one of the groups (scaled_dot, rest, tutorial-faX)
       - name: Select tutorials to run (scaled_dot)
         if: matrix.suite == 'scaled_dot'

.github/workflows/build-test-windows.yml

Lines changed: 7 additions & 0 deletions
@@ -148,6 +148,13 @@ jobs:
         cd ${{ env.NEW_WORKSPACE }}
         ${{ env.TRITON_TEST_CMD }} --core

+      - name: Run triton kernels tests
+        run: |
+          .venv\Scripts\activate.ps1
+          Invoke-BatchFile "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
+          cd ${{ env.NEW_WORKSPACE }}
+          ${{ env.TRITON_TEST_CMD }} --triton-kernels
+
       - name: Run interpreter tests
         run: |
           .venv\Scripts\activate.ps1

.github/workflows/pip-test-windows.yml

Lines changed: 7 additions & 0 deletions
@@ -131,6 +131,13 @@ jobs:
         cd ${{ env.NEW_WORKSPACE }}
         ${{ env.TRITON_TEST_CMD }} --interpreter

+      - name: Run triton kernels tests
+        run: |
+          .venv\Scripts\activate.ps1
+          Invoke-BatchFile "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
+          cd ${{ env.NEW_WORKSPACE }}
+          ${{ env.TRITON_TEST_CMD }} --triton-kernels
+
       - name: Run tutorials
         run: |
           .venv\Scripts\activate.ps1

.github/workflows/pip-test.yml

Lines changed: 4 additions & 0 deletions
@@ -71,6 +71,10 @@ jobs:
         run: |
           ${{ env.TRITON_TEST_CMD }} --interpreter --skip-pip-install

+      - name: Run triton kernels tests
+        run: |
+          ${{ env.TRITON_TEST_CMD }} --triton-kernels --skip-pip-install
+
       - name: Run Tutorials
         run: |
           ${{ env.TRITON_TEST_CMD }} --tutorial --skip-pip-install

python/triton/language/target_info.py

Lines changed: 6 additions & 0 deletions
@@ -52,3 +52,9 @@ def is_hip_cdna3():
 def is_hip_cdna4():
     target = current_target()
     return target is not None and target.arch == "gfx950"
+
+
+@constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
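The hunk above only shows the new is_xpu() helper; the num_sms() extension described in the commit message is not reproduced in this section. Below is a hedged, self-contained illustration of the described behaviour using plain PyTorch device queries rather than Triton's target machinery; torch.xpu.get_device_properties() and its max_compute_units field are assumptions based on the commit message.

```python
# Illustrative sketch only: the real num_sms() lives in python/triton/language/target_info.py
# and dispatches on current_target(); its CUDA/HIP branches are not reproduced here.
import torch


def num_sms_sketch() -> int:
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # On XPU the closest analogue of an SM count is the compute-unit count
        # (field name taken from the commit message, access path assumed).
        return torch.xpu.get_device_properties(0).max_compute_units
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).multi_processor_count
    return 0
```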

python/triton_kernels/tests/test_mxfp.py

Lines changed: 9 additions & 8 deletions
@@ -16,9 +16,9 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:


 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
-def test_mxfp4_rounding_cases(dst_dtype):
+def test_mxfp4_rounding_cases(dst_dtype, device):
     dst_dtype = dtype_str_to_torch(dst_dtype)
-    x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).cuda().bfloat16().view(1, -1, 1)
+    x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).to(device).bfloat16().view(1, -1, 1)
     quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
     dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
     assert dequant.flatten().tolist() == [6, 0, 0, 0.5, 1.0, 1.0, 1.0, 1.5], f"{dequant=}"
@@ -33,8 +33,8 @@ def test_mxfp4_rounding_cases(dst_dtype):

 @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
-def test_mxfp_quant_dequant(src_dtype, dst_dtype):
-    if "float8" in src_dtype and torch.cuda.get_device_capability()[0] < 9:
+def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
+    if "float8" in src_dtype and (device == "cuda" and torch.cuda.get_device_capability()[0] < 9):
         pytest.skip("Float8 not tested on A100")
     limit_range = src_dtype == "float8_e5m2" and dst_dtype == "float16"

@@ -48,14 +48,14 @@ def test_mxfp_quant_dequant(src_dtype, dst_dtype):
     max_val = 128

     # These are all the valid mxfp4 positive values.
-    pos_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val], device="cuda", dtype=dst_dtype)
+    pos_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val], device=device, dtype=dst_dtype)
     neg_vals = -pos_vals
     k_dim = torch.cat([pos_vals, neg_vals])
     k_dim = k_dim.reshape([k_dim.shape[0], 1])

     # We pick power of 2 scales since both the scales and their inverse only require exponent bits to be exactly
     # represented. This means we can store the scales exactly in the e8m0 format.
-    powers = torch.arange(-8, 8, device="cuda", dtype=dst_dtype)
+    powers = torch.arange(-8, 8, device=device, dtype=dst_dtype)
     scales = 2**powers
     scales = scales.reshape([1, powers.shape[0]])
     weight = k_dim * scales
@@ -85,13 +85,14 @@ def test_mxfp_casting(
     quant_dtype: str,
     dequant_dtype: str,
     rounding_mode: DequantScaleRoundingMode,
+    device,
 ):
-    if "float8" in quant_dtype and torch.cuda.get_device_capability()[0] < 9:
+    if "float8" in quant_dtype and (device == "cuda" and torch.cuda.get_device_capability()[0] < 9):
         pytest.skip("Float8 not tested on A100")
     quant_torch_type = dtype_str_to_torch(quant_dtype)
     dequant_torch_type = dtype_str_to_torch(dequant_dtype)
     # Generate random input tensor that is contiguous once axis is the last dimension
-    x = torch.randn(shape, device="cuda", dtype=dequant_torch_type)
+    x = torch.randn(shape, device=device, dtype=dequant_torch_type)

     # Quantize and check equivalence
     quant, scale = downcast_to_mxfp(x, quant_torch_type, axis, DEQUANT_SCALE_ROUNDING_MODE=rounding_mode)

python/triton_kernels/tests/test_routing.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 from triton_kernels.testing import assert_equal


-def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
+def init_data(n_tokens, n_expts_tot, device, dtype=torch.float16):
     logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True)
     return logits

@@ -32,7 +32,7 @@ def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_
     ref_logits = tri_logits.clone().detach().requires_grad_(True)

     if use_expt_indx:
-        rand_idx = lambda: torch.randperm(n_expts_tot, device="cuda", dtype=torch.int64)
+        rand_idx = lambda: torch.randperm(n_expts_tot, device=device, dtype=torch.int64)
         tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens_pad)])
         tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1)
         tri_expt_indx[n_tokens_raw:] = -99999  # should not be used
@@ -76,11 +76,11 @@ def _assert_indx_equal(ref, tri):
     assert_close(ref_logits.grad[:n_tokens_raw], tri_logits.grad[:n_tokens_raw])


-def bench_routing():
+def bench_routing(device):
     import triton.profiler as proton
     n_tokens = 8192
     n_expts_tot, n_expts_act = 128, 4
-    tri_logits = init_data(n_tokens, n_expts_tot)
+    tri_logits = init_data(n_tokens, n_expts_tot, device)
     proton.start("routing")
     proton.activate()
     for i in range(100):

python/triton_kernels/tests/test_swiglu.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def test_op(M, N, limit, device, alpha=0.5):
     # initialize expert data
     n_expts_tot = 6
     n_expts_act = 2
-    logits = init_routing_data(M, n_expts_tot).detach()
+    logits = init_routing_data(M, n_expts_tot, device).detach()
     routing_data, _, _ = routing_torch(logits, n_expts_act)
     n_tokens = routing_data.expt_hist.sum()
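The SwiGLU kernel change itself (excluding the NVIDIA-specific maxnreg launch parameter on XPU) is not part of the hunks shown in this section. Below is a hypothetical sketch of that pattern; the function name, the backend check via torch, and the 128 register cap are all placeholders, not the commit's actual code.

```python
# Hypothetical illustration of the maxnreg exclusion described in the commit message;
# the real SwiGLU kernel wrapper is not shown in this section.
import torch


def swiglu_launch_opts(num_warps: int = 4) -> dict:
    opts = {"num_warps": num_warps}
    on_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
    if not on_xpu:
        # maxnreg (a per-thread register cap) is an NVIDIA-specific Triton launch
        # parameter, so it is only passed on non-XPU backends; 128 is a placeholder value.
        opts["maxnreg"] = 128
    return opts
```

A wrapper would then expand these options at launch time, e.g. kernel[grid](..., **swiglu_launch_opts()), keeping the kernel body identical across backends.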

python/triton_kernels/tests/test_tensor_details/test_layout_blackwell.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 import pytest
+from triton._internal_testing import is_cuda
 import torch
 from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout

@@ -17,6 +18,7 @@
         (3, 2, 36),
     ],
 )
+@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
 def test_mxfp4_scale_roundtrip(shape):
     x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
     layout = BlackwellMXScaleLayout(x.shape)

python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py

Lines changed: 6 additions & 3 deletions
@@ -1,5 +1,5 @@
 import pytest
-from triton._internal_testing import is_cuda
+from triton._internal_testing import is_cuda, is_xpu
 from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
 from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
@@ -19,6 +19,7 @@
 @pytest.mark.parametrize("trans", [False, True])
 @pytest.mark.parametrize("mx_axis", [0, 1])
 @pytest.mark.parametrize("mma_version", [2, 3])
+@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
 def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
     x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
     if trans:
@@ -33,6 +34,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
 @pytest.mark.parametrize("mx_axis", [0, 1])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
+@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA")
 def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
     x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
     layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
@@ -71,8 +73,9 @@ def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m
     tl.store(Y + offs_y, y)


-@pytest.mark.skipif(not is_cuda(), reason="Only supported on cuda")
-@pytest.mark.skipif(not cuda_capability_geq(9), reason="Only supported for capability >= 9")
+@pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on cuda")
+@pytest.mark.skipif(not is_cuda() and not is_xpu(), reason="Only supported on cuda")
+@pytest.mark.skipif(is_cuda() and not cuda_capability_geq(9), reason="Only supported for capability >= 9")
 def test_upcast_mxfp4_to_bf16():
     mx_axis = 0
     num_warps = 4
