Commit 49a72f5
[Tests] Replace block_ptr with tensor_descriptor (#6846)
This changes tests that aren't specifically testing block_ptr to use tensor_descriptor instead.
1 parent: e163113 · commit: 49a72f5
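
For context, the swap inside a kernel looks roughly like this. The fragment below is illustrative only — it would sit inside a `@triton.jit` kernel body, and `X`, `M`, `N`, and the block sizes are made up, not taken from this commit. The descriptor drops the `offsets` and `order` arguments of `tl.make_block_ptr` and instead takes element offsets at each load/store call site.

```python
# Before: a block pointer bakes offsets and an axis order into the pointer object.
ptr = tl.make_block_ptr(base=X, shape=(M, N), strides=(N, 1),
                        offsets=(0, 0), block_shape=(32, 32), order=(1, 0))
tile = tl.load(ptr)

# After: a tensor descriptor describes the tensor once; offsets move to the
# call site, and the offsets/order arguments disappear.
desc = tl.make_tensor_descriptor(base=X, shape=[M, N], strides=[N, 1],
                                 block_shape=[32, 32])
tile = desc.load([0, 0])
```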

File tree

4 files changed: +45 −43 lines

- python/test/unit/conftest.py
- python/test/unit/language/test_core.py
- python/test/unit/test_perf_warning.py
- python/triton/_internal_testing.py

python/test/unit/conftest.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -88,3 +88,16 @@ def fresh_knobs_except_libraries(monkeypatch):
         yield fresh_function()
     finally:
         reset_function()
+
+
+@pytest.fixture
+def with_allocator():
+    import triton
+    from triton.runtime._allocation import NullAllocator
+    from triton._internal_testing import default_alloc_fn
+
+    triton.set_allocator(default_alloc_fn)
+    try:
+        yield
+    finally:
+        triton.set_allocator(NullAllocator())
```
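
Kernels that build tensor descriptors request scratch memory through the allocator registered with `triton.set_allocator`, which is what this fixture installs and tears down. Outside pytest, the equivalent setup would look like the minimal sketch below (assuming a CUDA device; `alloc_fn` mirrors the `default_alloc_fn` helper added in `_internal_testing.py` further down):

```python
import torch
import triton
from triton.runtime._allocation import NullAllocator


def alloc_fn(size: int, align: int, stream):
    # Triton calls this back when a kernel needs scratch memory,
    # e.g. for tensor descriptors.
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(alloc_fn)
try:
    ...  # run kernels that use tl.make_tensor_descriptor here
finally:
    triton.set_allocator(NullAllocator())  # reset, as the fixture's teardown does
```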

python/test/unit/language/test_core.py

Lines changed: 13 additions & 16 deletions
```diff
@@ -3528,32 +3528,29 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1:
 
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", ["int32", "int8"])
-@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)])
+@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 16)])
 @pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3])))
-def test_trans_4d(dtype_str, shape, perm, device):
+def test_trans_4d(dtype_str, shape, perm, device, with_allocator):
 
     @triton.jit
     def kernel(In, Out,  #
                in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr,
                ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr,
               trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr):
-        in_ptr = tl.make_block_ptr(
+        in_desc = tl.make_tensor_descriptor(
             base=In,
-            shape=(in_shape1, in_shape2, in_shape3, in_shape4),
-            strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1),
-            offsets=(0, 0, 0, 0),
-            block_shape=(in_shape1, in_shape2, in_shape3, in_shape4),
-            order=(3, 2, 1, 0),
+            shape=[in_shape1, in_shape2, in_shape3, in_shape4],
+            strides=[in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1],
+            block_shape=[in_shape1, in_shape2, in_shape3, in_shape4],
         )
-        out_ptr = tl.make_block_ptr(
+        out_desc = tl.make_tensor_descriptor(
             base=Out,
-            shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4),
-            strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1),
-            offsets=(0, 0, 0, 0),
-            block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4),
-            order=(3, 2, 1, 0),
+            shape=[ou_shape1 * ou_shape2 * ou_shape3 * ou_shape4],
+            strides=[1],
+            block_shape=[ou_shape1 * ou_shape2 * ou_shape3 * ou_shape4],
         )
-        tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4)))
+        val = in_desc.load([0, 0, 0, 0]).permute((trans1, trans2, trans3, trans4))
+        out_desc.store([0], val.reshape(out_desc.block_shape))
 
     input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape)
     expected = torch.permute(input, perm)
@@ -5145,7 +5142,7 @@ def kernel(ptr):
     assert "Descriptor block shape must have at least 16 bytes" in str(e.value.__cause__)
 
 
-def test_trans_reshape(device):
+def test_trans_reshape(device, with_allocator):
 
     @triton.jit
     def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr):
```
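
Putting the pieces together, here is a self-contained sketch of the load → permute → reshape → store pattern the rewritten test uses. It is modeled on the hunk above, not taken from the commit; it assumes a CUDA device and a Triton build with `tl.make_tensor_descriptor`, and the block sizes keep the innermost dimension well above the 16-byte minimum referenced in the second hunk.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def transpose_kernel(In, Out, H: tl.constexpr, W: tl.constexpr):
    in_desc = tl.make_tensor_descriptor(base=In, shape=[H, W], strides=[W, 1],
                                        block_shape=[H, W])
    # Store through a flat 1-D descriptor, as the test does: the permuted tile
    # is written back as one contiguous block.
    out_desc = tl.make_tensor_descriptor(base=Out, shape=[H * W], strides=[1],
                                         block_shape=[H * W])
    val = in_desc.load([0, 0]).permute((1, 0))
    out_desc.store([0], val.reshape(out_desc.block_shape))


# Descriptors need a scratch allocator on the host side (see the conftest fixture).
triton.set_allocator(lambda size, align, stream: torch.empty(size, dtype=torch.int8, device="cuda"))

x = torch.arange(32 * 64, dtype=torch.int32, device="cuda").reshape(32, 64)
y = torch.empty(32 * 64, dtype=torch.int32, device="cuda")
transpose_kernel[(1, )](x, y, 32, 64)
assert torch.equal(y.reshape(64, 32), x.t())
```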

python/test/unit/test_perf_warning.py

Lines changed: 15 additions & 27 deletions
```diff
@@ -32,40 +32,31 @@ def matmul_kernel(
         N,
         K,
         stride_am,
-        stride_ak,
-        stride_bk,
         stride_bn,
         stride_cm,
-        stride_cn,
     ):
-        a_block_ptr = tl.make_block_ptr(
+        a_desc = tl.make_tensor_descriptor(
             base=a_ptr,
-            shape=(M, K),
-            strides=(stride_am, stride_ak),
-            offsets=(0, 0),
-            block_shape=(32, 128),
-            order=(1, 0),
+            shape=[M, K],
+            strides=[stride_am, 1],
+            block_shape=[32, 128],
         )
-        b_block_ptr = tl.make_block_ptr(
+        b_desc = tl.make_tensor_descriptor(
             base=b_ptr,
-            shape=(K, N),
-            strides=(stride_bk, stride_bn),
-            offsets=(0, 0),
-            block_shape=(128, 32),
-            order=(0, 1),
+            shape=[K, N],
+            strides=[stride_bn, 1],
+            block_shape=[32, 128],
         )
-        c_block_ptr = tl.make_block_ptr(
+        c_desc = tl.make_tensor_descriptor(
             base=c_ptr,
-            shape=(M, N),
-            strides=(stride_cm, stride_cn),
-            offsets=(0, 0),
-            block_shape=(32, 32),
-            order=(1, 0),
+            shape=[M, N],
+            strides=[stride_cm, 1],
+            block_shape=[32, 32],
         )
-        a = tl.load(a_block_ptr)
-        b = tl.load(b_block_ptr)
+        a = a_desc.load([0, 0])
+        b = b_desc.load([0, 0]).T
         c = tl.dot(a, b)
-        tl.store(c_block_ptr, c)
+        c_desc.store([0, 0], c)
 
     signature = {
         "a_ptr": "*fp32",
@@ -75,11 +66,8 @@ def matmul_kernel(
         "N": "i32",
         "K": "i32",
         "stride_am": "i32",
-        "stride_ak": "i32",
-        "stride_bk": "i32",
         "stride_bn": "i32",
         "stride_cm": "i32",
-        "stride_cn": "i32",
     }
     with enable_diagnostics_context('remarks'):
         triton.compile(triton.compiler.ASTSource(
```
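
Note that every descriptor above passes a literal `1` as its innermost stride, i.e. the kernel now assumes the innermost dimension of each operand is contiguous; that appears to be why `stride_ak`, `stride_bk`, and `stride_cn` could be dropped from both the kernel parameters and the compile signature.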

python/triton/_internal_testing.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -179,6 +179,10 @@ def tma_skip_msg(byval_only=False):
 requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
 
 
+def default_alloc_fn(size: int, align: int, _):
+    return torch.empty(size, dtype=torch.int8, device="cuda")
+
+
 def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
     if isinstance(t, triton.runtime.jit.TensorWrapper):
         return t.base
```
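
A quick usage sketch of the new helper, mirroring what the `with_allocator` fixture does:

```python
import triton
from triton._internal_testing import default_alloc_fn

# Route descriptor scratch allocations through torch on the CUDA device.
triton.set_allocator(default_alloc_fn)
```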
