Skip to content

Commit 420c290

Browse files
authored
Use device fixture for more tests in test_core.py (#5885)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2bc85dc commit 420c290

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

python/test/unit/language/test_core.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2441,7 +2441,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
24412441
negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)]
24422442

24432443

2444-
def test_sum_dtype():
2444+
def test_sum_dtype(device):
24452445

24462446
@triton.jit
24472447
def kernel_dtype(out_ptr, init, in_dtype: tl.constexpr, out_dtype: tl.constexpr):
@@ -2461,7 +2461,7 @@ def kernel_default_float(out_ptr):
24612461
x = tl.sum(x)
24622462
tl.store(out_ptr, x)
24632463

2464-
out = torch.empty(1, dtype=torch.int32, device='cuda')
2464+
out = torch.empty(1, dtype=torch.int32, device=device)
24652465
kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=None)
24662466
assert out[0] == 32 * 32
24672467

@@ -2477,9 +2477,9 @@ def kernel_default_float(out_ptr):
24772477
kernel_default_int[(1, )](out)
24782478
assert out[0] == 32 * 32
24792479

2480-
out = torch.empty(1, dtype=torch.bfloat16, device='cuda')
2480+
out = torch.empty(1, dtype=torch.bfloat16, device=device)
24812481
kernel_default_float[(1, )](out)
2482-
torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device='cuda'))
2482+
torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device=device))
24832483

24842484

24852485
@triton.jit
@@ -2675,16 +2675,16 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
26752675

26762676

26772677
@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)])
2678-
def test_scan_1d(M, N):
2678+
def test_scan_1d(M, N, device):
26792679

26802680
@triton.jit
26812681
def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr):
26822682
input = tl.load(in_ptr + tl.arange(0, M))
26832683
output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M])
26842684
tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N]))
26852685

2686-
x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device='cuda')
2687-
output = torch.empty(M * N, dtype=torch.int32, device='cuda')
2686+
x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device=device)
2687+
output = torch.empty(M * N, dtype=torch.int32, device=device)
26882688

26892689
scan_kernel[(1, )](output, x, M, N)
26902690

@@ -4813,14 +4813,14 @@ def kernel():
48134813

48144814

48154815
@pytest.mark.interpreter
4816-
def test_tma_load_block_shape_err():
4816+
def test_tma_load_block_shape_err(device):
48174817

48184818
@triton.jit
48194819
def kernel(ptr):
48204820
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32])
48214821
desc.load([0, 0])
48224822

4823-
input = torch.empty((128, 128), dtype=torch.int32, device='cuda')
4823+
input = torch.empty((128, 128), dtype=torch.int32, device=device)
48244824
errc = triton.CompilationError if not is_interpreter() else InterpreterError
48254825
with pytest.raises(errc) as e:
48264826
kernel[(1, )](input)
@@ -4829,14 +4829,14 @@ def kernel(ptr):
48294829

48304830

48314831
@pytest.mark.interpreter
4832-
def test_tma_store_block_shape_err():
4832+
def test_tma_store_block_shape_err(device):
48334833

48344834
@triton.jit
48354835
def kernel(ptr):
48364836
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8])
48374837
desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16))
48384838

4839-
input = torch.empty((128, 128), dtype=torch.int16, device='cuda')
4839+
input = torch.empty((128, 128), dtype=torch.int16, device=device)
48404840
errc = triton.CompilationError if not is_interpreter() else InterpreterError
48414841
with pytest.raises(errc) as e:
48424842
kernel[(1, )](input)

0 commit comments

Comments (0)