Commit 255c7f6

[tensor-descriptor]: Add test_tensor_descriptor_load_nd and other tests (#4141)
Add the tests `test_tensor_descriptor_load_nd`, `test_tensor_descriptor_store_nd` and `test_tensor_descriptor_batched_gemm_2d_tma`, plus the kernels `kernel_make_tensor_descriptor_loop_carried` and `batched_gemm_2d_tma_kernel`.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 1d74db2 commit 255c7f6

File tree

1 file changed: +313 -1 lines changed

python/test/unit/intel/test_tensor_descriptor.py

Lines changed: 313 additions & 1 deletion
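All of the new tests drive the same device-side pattern: a kernel builds a descriptor over a global buffer with `tl.make_tensor_descriptor` and moves whole blocks through `desc.load` / `desc.store`, while the host first installs a scratch allocator via `triton.set_allocator`. A minimal sketch of that pattern, not taken from the commit (the kernel name, sizes and block shape are illustrative; "xpu" is the device these tests target):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_tile_kernel(dst, src, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # A descriptor describes a 2D view of a buffer plus the block shape to move.
        src_desc = tl.make_tensor_descriptor(src, shape=[M, N], strides=[N, 1],
                                             block_shape=[BLOCK_M, BLOCK_N])
        dst_desc = tl.make_tensor_descriptor(dst, shape=[M, N], strides=[N, 1],
                                             block_shape=[BLOCK_M, BLOCK_N])
        block = src_desc.load([0, 0])    # load one [BLOCK_M, BLOCK_N] tile
        dst_desc.store([0, 0], block)    # store it back at the same offset


    def alloc_fn(size: int, align: int, stream):
        # make_tensor_descriptor needs host-provided scratch space for the descriptors;
        # every test in this file installs an allocator like this one.
        return torch.empty(size, dtype=torch.int8, device="xpu")


    triton.set_allocator(alloc_fn)
    a = torch.randn(128, 128, device="xpu")
    b = torch.empty_like(a)
    copy_tile_kernel[(1, )](b, a, 128, 128, BLOCK_M=64, BLOCK_N=64)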
@@ -203,7 +203,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 def test_tensor_descriptor_store3d(dtype_str, K_BLOCK):

     if dtype_str == 'bfloat16':
-        return pytest.skip("FIXME: bfloat16 test fails verification")
+        return pytest.skip("FIXME: issue #4137")

     @triton.jit
     def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr,
@@ -248,6 +248,138 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     torch.testing.assert_close(expect, actual)


+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1])
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128])
+def test_tensor_descriptor_load_nd(dtype_str, num_ctas, ndim, INNER_BLOCK):
+
+    if ndim not in [1] or dtype_str not in ["uint16", "uint32"]:
+        return pytest.skip("FIXME: issue #4139")
+
+    @triton.jit
+    def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
+        desc = tl.make_tensor_descriptor(
+            a_ptr,
+            shape=shape,
+            strides=strides,
+            block_shape=BLOCK_SHAPE,
+        )
+        ndim: tl.constexpr = len(BLOCK_SHAPE)
+
+        offs = (0, ) * ndim
+        block = desc.load(offs)
+
+        idx = tl.full(BLOCK_SHAPE, 0, tl.int32)
+        stride = 1
+        for k in tl.static_range(ndim - 1, -1, -1):
+            arange = tl.arange(0, BLOCK_SHAPE[k])
+            for _ in tl.static_range(k):
+                arange = tl.expand_dims(arange, 0)
+            for _ in tl.static_range(k + 1, ndim):
+                arange = tl.expand_dims(arange, -1)
+
+            idx += arange * stride
+            stride *= BLOCK_SHAPE[k]
+
+        tl.store(out_ptr + idx, block)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    alloc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:]
+    inp = to_triton(numpy_random(alloc_shape, dtype_str), device="xpu", dst_type=dtype_str)
+    inp.data = inp.data[..., :INNER_BLOCK - 3]
+
+    if INNER_BLOCK * inp.element_size() < 32:
+        return pytest.xfail("Invalid last dim size")
+
+    BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:]
+    out = inp.new_empty(BLOCK_SHAPE)
+
+    constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE)
+    kernel[(1, )](out, inp, inp.shape, inp.stride(), constexpr_block_shape, num_ctas=num_ctas)
+
+    # Check in-bounds
+    actual = unwrap_tensor(out)
+    expect = unwrap_tensor(inp)
+    idx = [slice(None, s) for s in inp.shape]
+    torch.testing.assert_close(expect, actual[idx])
+
+    # Check out-of-bounds
+    actual[idx].zero_()
+    expect = expect.new_zeros(BLOCK_SHAPE)
+    torch.testing.assert_close(expect, actual)
+
+
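A note on the kernels used by `test_tensor_descriptor_load_nd` above and `test_tensor_descriptor_store_nd` below: the `idx` tensor they build is the row-major linear offset of every element of the block, so `out_ptr + idx` (or `a_ptr + idx`) walks a C-contiguous buffer in block order. A host-side sketch of the same construction, assuming NumPy and an illustrative block shape:

    import numpy as np

    BLOCK_SHAPE = (4, 8, 32)  # illustrative; the tests use (2, 2, 4, 8, INNER_BLOCK)[-ndim:]

    idx = np.zeros(BLOCK_SHAPE, dtype=np.int32)
    stride = 1
    for k in range(len(BLOCK_SHAPE) - 1, -1, -1):
        # arange along axis k, broadcast over every other axis
        shape = [1] * len(BLOCK_SHAPE)
        shape[k] = BLOCK_SHAPE[k]
        idx += np.arange(BLOCK_SHAPE[k]).reshape(shape) * stride
        stride *= BLOCK_SHAPE[k]

    # Row-major offsets: identical to enumerating the block in C order.
    assert np.array_equal(idx, np.arange(np.prod(BLOCK_SHAPE)).reshape(BLOCK_SHAPE))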
+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1])
+@pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5])
+@pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128])
+def test_tensor_descriptor_store_nd(dtype_str, num_ctas, ndim, INNER_BLOCK):
+
+    if ndim not in [1]:
+        return pytest.skip("FIXME: issue #4140")
+
+    @triton.jit
+    def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
+        desc = tl.make_tensor_descriptor(
+            out_ptr,
+            shape=shape,
+            strides=strides,
+            block_shape=BLOCK_SHAPE,
+        )
+        ndim: tl.constexpr = len(BLOCK_SHAPE)
+
+        idx = tl.full(BLOCK_SHAPE, 0, tl.int32)
+        stride = 1
+        for k in tl.static_range(ndim - 1, -1, -1):
+            arange = tl.arange(0, BLOCK_SHAPE[k])
+            for _ in tl.static_range(k):
+                arange = tl.expand_dims(arange, 0)
+            for _ in tl.static_range(k + 1, ndim):
+                arange = tl.expand_dims(arange, -1)
+
+            idx += arange * stride
+            stride *= BLOCK_SHAPE[k]
+
+        block = tl.load(a_ptr + idx)
+
+        offs = (0, ) * ndim
+        desc.store(offs, block)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:]
+    inp = to_triton(numpy_random(BLOCK_SHAPE, dtype_str), device="xpu", dst_type=dtype_str)
+
+    if INNER_BLOCK * inp.element_size() < 32:
+        return pytest.xfail("Invalid last dim size")
+
+    out = inp.new_empty(BLOCK_SHAPE)
+    out.data.fill_(-1)
+
+    desc_shape = (1, 1, 3, 7, INNER_BLOCK)[-ndim:]
+    constexpr_block_shape = tuple(tl.constexpr(v) for v in BLOCK_SHAPE)
+    kernel[(1, )](out, inp, desc_shape, out.stride(), constexpr_block_shape, num_ctas=num_ctas)
+
+    # Check in-bounds
+    actual = unwrap_tensor(out)
+    expect = unwrap_tensor(inp)
+    idx = [slice(None, s) for s in desc_shape]
+    torch.testing.assert_close(expect[idx], actual[idx])
+
+    # Check out-of-bounds
+    actual[idx].fill_(-1)
+    expect = expect.new_full(BLOCK_SHAPE, -1)
+    torch.testing.assert_close(expect, actual)
+
+
 @triton.jit(noinline=False)
 def tensor_descriptor_in_function_helper(out_ptr, in_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
     in_desc = tl.make_tensor_descriptor(
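Both nd tests above xfail when the innermost block dimension spans fewer than 32 bytes (`INNER_BLOCK * inp.element_size() < 32`, reported as "Invalid last dim size"). A quick enumeration of that guard for a few illustrative element sizes (the real parametrization comes from `tma_dtypes`):

    for dtype, itemsize in [("uint8", 1), ("uint16", 2), ("uint32", 4), ("float64", 8)]:
        for INNER_BLOCK in [16, 32, 64, 128]:
            outcome = "xfail" if INNER_BLOCK * itemsize < 32 else "runs"
            print(f"{dtype:>8} INNER_BLOCK={INNER_BLOCK:<3} -> {outcome}")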
@@ -465,6 +597,186 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)


+@triton.jit
+def kernel_make_tensor_descriptor_loop_carried(a_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr):
+    # Test that descriptors work with loop-carried values.
+    pid = tl.program_id(0)
+    moffset = MBLOCK * pid
+
+    a_desc = tl.make_tensor_descriptor(
+        a_ptr,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[MBLOCK, NBLOCK],
+    )
+
+    for i in range(0, N, NBLOCK):
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        if i % (3 * NBLOCK) == 0:
+            a_desc = tl.make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl.tensor_descriptor)
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        a = a_desc.load([moffset, i])
+        a_desc.store([moffset, i], a + 10)
+
+    n = 0
+    while n < N:
+        assert isinstance(a_desc, tl.tensor_descriptor)
+        if n % (3 * NBLOCK) == 0:
+            assert isinstance(a_desc, tl.tensor_descriptor)
+            a_desc = tl.make_tensor_descriptor(
+                a_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[MBLOCK, NBLOCK],
+            )
+            assert isinstance(a_desc, tl.tensor_descriptor)
+        a = a_desc.load([moffset, n])
+        a_desc.store([moffset, n], a + 5)
+
+        n += NBLOCK
+
+
+@pytest.mark.interpreter
+def test_make_tensor_descriptor_loop_carried():
+    return pytest.skip("FIXME: issue #4132")
+
+    device = "xpu"
+    M, N = 64, 512
+    torch.manual_seed(42)
+    A = torch.randn((M, N), dtype=torch.float32, device=device)
+    MBLOCK, NBLOCK = 8, 128
+    grid = (triton.cdiv(M, MBLOCK), )
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        assert size == 128 * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    ref_out = A + 15
+    kernel_make_tensor_descriptor_loop_carried[grid](
+        A,
+        M,
+        N,
+        MBLOCK,
+        NBLOCK,
+    )
+    torch.testing.assert_close(ref_out, A)
+
+
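The `ref_out = A + 15` expectation in `test_make_tensor_descriptor_loop_carried` follows from the kernel: `N` is a multiple of `NBLOCK`, so the `for` loop adds 10 to each `[MBLOCK, NBLOCK]` tile of the program's rows exactly once and the `while` loop adds another 5. A host-side restatement of that arithmetic, assuming torch:

    import torch

    M, N = 64, 512
    MBLOCK, NBLOCK = 8, 128
    A = torch.randn(M, N)

    ref = A.clone()
    for i in range(0, N, NBLOCK):   # mirrors the kernel's `for` loop: +10 per tile
        ref[:, i:i + NBLOCK] += 10
    n = 0
    while n < N:                    # mirrors the kernel's `while` loop: +5 per tile
        ref[:, n:n + NBLOCK] += 5
        n += NBLOCK

    torch.testing.assert_close(ref, A + 15)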
+@triton.jit
+def batched_gemm_2d_tma_kernel(a_ptr, b_ptr, c_ptr,  #
+                               B, M, N, K,  #
+                               dtype: tl.constexpr,  #
+                               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
+                               NUM_SMS: tl.constexpr):
+    start_pid = tl.program_id(axis=0)
+    num_tiles_m = tl.cdiv(M, BLOCK_M)
+    num_tiles_n = tl.cdiv(N, BLOCK_N)
+    k_tiles = tl.cdiv(K, BLOCK_K)
+    num_tiles_per_batch = num_tiles_m * num_tiles_n
+    num_tiles = B * num_tiles_per_batch
+
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    tile_m = 0
+    tile_n = 0
+    tile_b = 0
+
+    offs_m = 0
+    offs_n = 0
+    offs_b = 0
+
+    a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
+    b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
+    c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
+
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for _ in range(k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            tile_b = tile_id // num_tiles_per_batch
+            tile_m = (tile_id // num_tiles_n) % num_tiles_m
+            tile_n = tile_id % num_tiles_n
+
+            offs_b = tile_b
+            offs_m = tile_m * BLOCK_M
+            offs_n = tile_n * BLOCK_N
+
+            a_desc = tl.make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
+            b_desc = tl.make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
+            c_desc = tl.make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])
+
+        offs_k = ki * BLOCK_K
+
+        a = a_desc.load([offs_m, offs_k])
+        b = b_desc.load([offs_n, offs_k])
+        accumulator = tl.dot(a, b.T, accumulator)
+
+        if ki == k_tiles - 1:
+            c = accumulator.to(dtype)
+
+            c_desc.store([offs_m, offs_n], c)
+            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+
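`batched_gemm_2d_tma_kernel` is a persistent kernel: each program walks the flattened tile ids `start_pid, start_pid + NUM_SMS, ...` and decodes batch, row and column indices from each id, rebuilding its descriptors whenever the tile (and hence the batch offset) changes. A plain-Python sketch of that decode with deliberately tiny, illustrative sizes:

    B, num_tiles_m, num_tiles_n, NUM_SMS = 2, 2, 2, 3
    num_tiles_per_batch = num_tiles_m * num_tiles_n
    num_tiles = B * num_tiles_per_batch

    for start_pid in range(NUM_SMS):
        # program `start_pid` owns tiles start_pid, start_pid + NUM_SMS, ...
        tiles = []
        for tile_id in range(start_pid, num_tiles, NUM_SMS):
            tile_b = tile_id // num_tiles_per_batch
            tile_m = (tile_id // num_tiles_n) % num_tiles_m
            tile_n = tile_id % num_tiles_n
            tiles.append((tile_b, tile_m, tile_n))
        print(start_pid, tiles)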
+@pytest.mark.interpreter
+def test_tensor_descriptor_batched_gemm_2d_tma():
+    return pytest.skip("FIXME: issue #4132")
+
+    device = "xpu"
+    BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
+    if is_interpreter():
+        B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K
+    else:
+        B, M, N, K = 2, 1024, 1024, 128
+    NUM_SMS = 96
+    num_stages = 3
+
+    grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), )
+
+    a = torch.randn((B, M, K), device=device, dtype=torch.float16)
+    b = torch.randn((B, N, K), device=device, dtype=torch.float16)
+    c = torch.empty((B, M, N), device=device, dtype=torch.float16)
+
+    expect = torch.bmm(a, b.mT)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        # TODO: should only need num_stages * 3 descriptors per SM
+        assert size == 128 * 3 * (num_stages + 1) * grid[0]
+        assert align == 128
+        assert stream == 0
+        return torch.empty(size, dtype=torch.int8, device="xpu")
+
+    triton.set_allocator(alloc_fn)
+
+    batched_gemm_2d_tma_kernel[grid](
+        a, b, c,  #
+        B, M, N, K,  #
+        tl.float16,  #
+        BLOCK_M, BLOCK_N, BLOCK_K,  #
+        NUM_SMS,  #
+        num_stages=num_stages, num_warps=8)
+    torch.xpu.synchronize()
+
+    torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3)
+
+
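The `alloc_fn` assertion in the test above sizes the descriptor scratch buffer as `128 * 3 * (num_stages + 1)` bytes per program: a 128-byte slot per descriptor, three descriptors (`a_desc`, `b_desc`, `c_desc`), and `num_stages + 1` sets of them (the TODO notes this is one set more than should be necessary). Worked out for the non-interpreter shapes, as a sketch:

    NUM_SMS, num_stages = 96, 3
    B, M, N, BLOCK_M, BLOCK_N = 2, 1024, 1024, 128, 256
    grid0 = min(NUM_SMS, B * (M // BLOCK_M) * (N // BLOCK_N))  # 64 programs
    size = 128 * 3 * (num_stages + 1) * grid0                  # bytes handed to alloc_fn
    print(grid0, size)                                         # 64, 98304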
 @triton.jit
 def batched_gemm_3d_tma_kernel(a_ptr, b_ptr, c_ptr,  #
                                B, M, N, K,  #