Commit 2efb067

Generalize unit tests for different backends (#5576)
Generalize unit tests for different backends, for example by not hard-coding `device` to `cuda`.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 74de6b4 commit 2efb067
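
The `device` argument these tests now take is supplied by a pytest fixture rather than a hard-coded string. A rough sketch of how such a fixture can be wired up (a hypothetical conftest.py, not part of this commit's diff):

# Hypothetical conftest.py sketch (not shown in this commit): expose the
# target device as a pytest fixture so individual tests never hard-code "cuda".
import pytest


def pytest_addoption(parser):
    parser.addoption("--device", action="store", default="cuda",
                     help="torch device string the unit tests should run on")


@pytest.fixture
def device(request):
    return request.config.getoption("--device")

With something like this in place, running pytest with `--device xpu` (for example) routes every `device`-parameterized test to that backend.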

7 files changed (+24 lines, -20 lines)

python/test/regression/test_cast_matmul.py

Lines changed: 1 addition & 2 deletions
@@ -85,12 +85,11 @@ def matmul_kernel(A, B, C, M, N, K, #
                          for w in input_dtypes
                          for x in input_dtypes #
                          for o in out_dtypes])
-def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype, device):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
     if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
         pytest.skip("skip due to bug on HIP path")
-    device = torch.cuda.current_device()
     x_dtype: torch.dtype = getattr(torch, x_dtype)
     w_dtype: torch.dtype = getattr(torch, w_dtype)

python/test/unit/language/test_conversions.py

Lines changed: 1 addition & 1 deletion
@@ -333,7 +333,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0):
         pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")

-    if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)):
+    if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.is_available() and torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")

     if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
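
On a machine without CUDA, `torch.cuda.get_device_capability` raises instead of returning a capability, which is why the added `torch.cuda.is_available()` check must be evaluated first. Since `and` binds tighter than `or`, the amended condition groups as in the sketch below (the helper name and parameter are illustrative, not from the test suite):

# Illustrative sketch of how the amended skip condition groups; `hip_backend`
# stands in for the suite's is_hip() helper and is not a real fixture.
import torch


def skip_rtne_downcast(hip_backend: bool) -> bool:
    # `and` binds tighter than `or`, so the capability query only runs when
    # torch.cuda.is_available() is True; otherwise it would raise.
    return hip_backend or (torch.cuda.is_available()
                           and torch.cuda.get_device_capability(0) < (9, 0))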

python/test/unit/language/test_core.py

Lines changed: 12 additions & 10 deletions
@@ -217,8 +217,7 @@ def __str__(self):


 def is_layout_applicable(layout) -> bool:
-    common_layouts = [BlockedLayout, SharedLayout]
-    if layout in common_layouts:
+    if isinstance(layout, (BlockedLayout, SharedLayout)):
         return True
     elif isinstance(layout, SliceLayout):
         return is_layout_applicable(layout.parent)
@@ -1447,6 +1446,7 @@ def kernel(X, Y, Z):
                          for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos']
                          for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']]))
 def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
+    check_type_supported(dtype_x_str, device)
     if is_interpreter():
         if dtype_x_str == 'float16':
             pytest.skip("Only test atomic float16 ops on GPU")
@@ -1523,6 +1523,7 @@ def kernel(X):
                          for num_ctas in num_ctas_list
                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
     shape0, shape1 = shape
     # triton kernel

@@ -2874,7 +2875,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov


 @pytest.mark.parametrize("M", [32, 64, 128, 256])
-@pytest.mark.parametrize("src_layout", layouts)
+@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
 def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):

     ir = f"""
@@ -3807,7 +3808,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_

     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
         if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties(
-                torch.cuda.current_device())["max_shared_mem"] < 131072:
+                triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072:
             pytest.skip(
                 "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU."
             )
@@ -6550,7 +6551,7 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
     ([128, 64], [256, 64], 0),
     ([128, 64], [128, 128], 1),
 ])
-def test_gather(src_shape, indices_shape, axis):
+def test_gather(src_shape, indices_shape, axis, device):

     def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
@@ -6562,8 +6563,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):

         return output

-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)
     result = triton_gather(src, axis, indices)
     torch.testing.assert_close(result, ref, rtol=0, atol=0)
@@ -6580,7 +6581,8 @@ def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
        "linear<{register = [[0, 2], [32, 0], [0, 32], [2, 0], [0, 16], [64, 0], [128, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>"
     ),
 ])
-def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path):
+def test_gather_warp_shuffle(src_shape, indices_shape, axis, src_layout, indices_layout, tmp_path: pathlib.Path,
+                             device):
     if is_hip():
         pytest.skip("warp-local gather has issues on HIP")

@@ -6623,8 +6625,8 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
    \1 = ttg.convert_layout %out : tensor<""" + output_spec + r""", #idx_layout> -> tensor<""" + output_spec + r""", \6>"""
         return re.sub(pat, repl, ir)

-    src = torch.randn(src_shape, device='cuda')
-    indices = torch.randint(0, src.shape[axis], indices_shape, device='cuda')
+    src = torch.randn(src_shape, device=device)
+    indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
     ref = torch.gather(src, axis, indices)

     output, compiled = prepare_kernel(src, axis, indices)
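
In the test_dot3d hunk above, the CUDA-only `torch.cuda.current_device()` call is replaced by the active Triton driver, which resolves the current device for whichever backend is loaded. A minimal usage sketch of that pattern, assuming a working Triton installation with an active backend:

# Sketch: query device properties through Triton's active driver rather than
# torch.cuda, so the same check works on CUDA, HIP, and other backends.
import triton

dev = triton.runtime.driver.active.get_current_device()
props = triton.runtime.driver.active.utils.get_device_properties(dev)
# "max_shared_mem" is the field the test above compares against 131072 bytes.
print(props["max_shared_mem"])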

python/test/unit/language/test_random.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def kernel(X, seed):

     x = torch.empty(1, dtype=torch.float32, device=device)
     with pytest.raises(triton.compiler.errors.CompilationError):
-        seed0 = torch.zeros(1, dtype=torch.int32, device="cuda")
+        seed0 = torch.zeros(1, dtype=torch.int32, device=device)
         kernel[(1, )](x, seed0)
     with pytest.raises(triton.compiler.errors.CompilationError):
         seed1 = 2.3

python/test/unit/language/test_tuple.py

Lines changed: 4 additions & 4 deletions
@@ -25,7 +25,7 @@ def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):


 @pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
-def test_index(size, device="cuda"):
+def test_index(size, device):
     vals = tuple([i + 1 for i in range(size)])
     rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
     _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
@@ -51,7 +51,7 @@ def _tuple_assign(XPtrs, YPtrs, values):
     tl.store(Y[2], y[2])


-def test_assign(device="cuda"):
+def test_assign(device):
     vals = (2., 3.)
     x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
     y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
@@ -91,7 +91,7 @@ def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
     _tuple_fn0(Ptr, 15, (-1, None, tuple1))


-def test_serialize(device="cuda"):
+def test_serialize(device):
     x0 = torch.tensor([8], dtype=torch.int32, device=device)
     x1 = torch.tensor([12], dtype=torch.int32, device=device)
     y0 = torch.tensor([10], dtype=torch.int32, device=device)
@@ -133,7 +133,7 @@ def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.conste
     tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))


-def test_namedtuple(device="cuda"):
+def test_namedtuple(device):
     x = torch.randn((32, 32), dtype=torch.float32, device=device)
     y = torch.empty((16, 16), dtype=torch.float32, device=device)
     a = torch.tensor([5.2], dtype=torch.float32, device=device)

python/test/unit/runtime/test_autotuner.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ def do_bench(kernel_call, quantiles):

 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
+    if use_cuda_graph and not torch.cuda.is_available():
+        pytest.xfail("CUDA is not available")
+
     M, N = 1024, 16
     src = torch.randn(M * N, device=device)
     dst = torch.empty(M * N, device=device)

python/test/unit/runtime/test_cache.py

Lines changed: 2 additions & 2 deletions
@@ -577,7 +577,7 @@ def cache_hook(*args, **kwargs):
     assert pointer_range_32 == [(0, )]


-def test_function_arguments():
+def test_function_arguments(device):

     @triton.jit
     def func1():
@@ -601,7 +601,7 @@ def kernel(Y, fn: tl.constexpr, fn_args):

     JITFunction.cache_hook = None
     JITFunction.compiled_hook = None
-    y = torch.zeros((5, ), dtype=torch.int32, device="cuda")
+    y = torch.zeros((5, ), dtype=torch.int32, device=device)
     kernel[(1, )](y[0], func1, tuple())
     kernel[(1, )](y[1], func2, tuple())
     kernel[(1, )](y[2], func3, (3, ))
