|
4 | 4 |
|
5 | 5 | import triton
|
6 | 6 | import triton.language as tl
|
7 |
| -from triton._internal_testing import is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes |
| 7 | +from triton._internal_testing import is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy |
8 | 8 | from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
|
9 | 9 | from typing import Optional
|
10 |
| -from triton._internal_testing import is_cuda, is_hip |
| 10 | +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3 |
| 11 | +from triton.tools.tensor_descriptor import TensorDescriptor |
| 12 | +from triton import CompilationError |
11 | 13 |
|
12 | 14 |
|
13 | 15 | @pytest.mark.interpreter
|
@@ -1434,3 +1436,140 @@ def alloc_fn(size: int, align: int, steam):
|
1434 | 1436 |
|
1435 | 1437 | ref = torch_scatter_rows(input, idx, y, BLOCK_Y, X, Y)
|
1436 | 1438 | torch.testing.assert_close(ref, output, atol=0, rtol=0)
|
| 1439 | + |
| 1440 | + |
# Dtypes accepted by each descriptor atomic-reduce kind on the *native* path
# (taken when `is_cuda()` and compute capability >= 9 — see
# test_tensor_descriptor_reduce).  Combinations outside these sets are expected
# to fail compilation/execution rather than silently miscompute.
NATIVE_SUPPORTED_REDUCE_DTYPES = {
    "add": {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16},
    "min": {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16},
    "max": {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16},
    "and": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "or": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "xor": {tl.uint32, tl.int32, tl.uint64, tl.int64},
}
# Dtypes accepted on the non-native (software fallback) path.  Note the
# fallback supports no floating-point min/max, unlike the native table above.
FALLBACK_SUPPORTED_REDUCE_DTYPES = {
    "add": {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16},
    "min": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "max": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "and": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "or": {tl.uint32, tl.int32, tl.uint64, tl.int64},
    "xor": {tl.uint32, tl.int32, tl.uint64, tl.int64},
}
| 1457 | + |
| 1458 | + |
def min_op(a, b):
    """Reference elementwise minimum of two device tensors.

    Round-trips through numpy so the comparison semantics are independent of
    the kernel under test; the result lives on the same device as ``a``.
    """
    lhs = to_numpy(a)
    rhs = to_numpy(b)
    result = to_triton(np.minimum(lhs, rhs), device=a.device)
    return unwrap_tensor(result)
| 1462 | + |
| 1463 | + |
def max_op(a, b):
    """Reference elementwise maximum of two device tensors.

    Mirrors ``min_op``: compute with numpy on the host, then move the result
    back onto ``a``'s device.
    """
    lhs = to_numpy(a)
    rhs = to_numpy(b)
    result = to_triton(np.maximum(lhs, rhs), device=a.device)
    return unwrap_tensor(result)
| 1467 | + |
| 1468 | + |
# Host-side reference implementation for each reduction kind; used to compute
# the expected output that the kernel's atomic reduce is checked against.
REDUCE_OP = {
    "add": lambda a, b: unwrap_tensor(a) + unwrap_tensor(b),
    "min": min_op,
    "max": max_op,
    "and": lambda a, b: torch.bitwise_and(unwrap_tensor(a), unwrap_tensor(b)),
    "or": lambda a, b: torch.bitwise_or(unwrap_tensor(a), unwrap_tensor(b)),
    "xor": lambda a, b: torch.bitwise_xor(unwrap_tensor(a), unwrap_tensor(b)),
}
| 1477 | + |
# (kind, dtype_str, M_BLOCK, N_BLOCK) combinations skipped on HIP CDNA3
# because they are known-broken there ("Broken on rocm" skip in the test).
REDUCE_SKIP_HIP_CDNA3 = [
    ("min", "int32", 1, 1024),
    ("max", "int32", 1, 1024),
    ("add", "bfloat16", 1, 1024),
]
| 1483 | + |
| 1484 | + |
# TODO: interpreter support
# @pytest.mark.interpreter
@pytest.mark.parametrize("kind", ["add", "min", "max", "and", "or", "xor"])
@pytest.mark.parametrize("dtype_str", tma_dtypes)
@pytest.mark.parametrize("num_ctas", [1, 2])
@pytest.mark.parametrize("descriptor", ["host", "device"])
@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128), (512, 32), (1, 1024)])
def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK, N_BLOCK):
    """Exercise atomic reductions through a tensor descriptor.

    Each kernel program loads one (M_BLOCK, N_BLOCK) tile of ``inp`` and
    atomically reduces it (add/min/max/and/or/xor) into ``out`` via either a
    host-created TensorDescriptor or one built in-kernel with
    ``tl.make_tensor_descriptor``.  Unsupported (kind, dtype) combinations are
    expected to raise; supported ones are checked against the host-side
    REDUCE_OP reference.
    """
    # Native descriptor atomics require CUDA compute capability >= 9; all
    # other configurations take the software fallback path.
    is_native = is_cuda() and torch.cuda.get_device_capability()[0] >= 9
    if not is_native:
        # The fallback path supports neither multi-CTA launches nor host-side
        # descriptors yet, and a few combos are known-broken on CDNA3.
        if num_ctas != 1:
            pytest.skip("Multi-CTA not supported")
        if descriptor == "host":
            pytest.skip("NYI: Host side tensor descriptor fallback")
        if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
            pytest.skip("Broken on rocm")

    # debug=True keeps the in-kernel `assert`s on the descriptor metadata live.
    @triton.jit(debug=True)
    def kernel(out_desc, out_ptr, a_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, kind: tl.constexpr):
        moffset = tl.program_id(0) * M_BLOCK
        noffset = tl.program_id(1) * N_BLOCK

        # Flat row-major offsets of this program's tile within the (M, N) input.
        midx = moffset + tl.arange(0, M_BLOCK)[:, None]
        nidx = noffset + tl.arange(0, N_BLOCK)[None, :]
        idx = midx * N + nidx

        val = tl.load(a_ptr + idx)

        # "device" descriptor mode passes out_desc=None and builds the
        # descriptor in-kernel; "host" mode passes a ready-made one.
        if out_desc is None:
            desc = tl.make_tensor_descriptor(
                out_ptr,
                shape=[M, N],
                strides=[N, 1],
                block_shape=[M_BLOCK, N_BLOCK],
            )
        else:
            desc = out_desc

        # Sanity-check descriptor metadata regardless of how it was created.
        assert desc.shape[0] == M
        assert desc.shape[1] == N
        assert desc.strides[0] == N
        assert desc.strides[1] == 1
        assert desc.block_shape == [M_BLOCK, N_BLOCK]
        # Dispatch on the constexpr reduction kind; static_assert documents
        # that "xor" is the only remaining possibility.
        if kind == "add":
            desc.atomic_add([moffset, noffset], val)
        elif kind == "min":
            desc.atomic_min([moffset, noffset], val)
        elif kind == "max":
            desc.atomic_max([moffset, noffset], val)
        elif kind == "and":
            desc.atomic_and([moffset, noffset], val)
        elif kind == "or":
            desc.atomic_or([moffset, noffset], val)
        else:
            tl.static_assert(kind == "xor")
            desc.atomic_xor([moffset, noffset], val)

    # 2x2 grid of blocks; `out` starts from random data so the reduction's
    # read-modify-write behavior is actually observed.
    M, N = M_BLOCK * 2, N_BLOCK * 2
    rs = np.random.RandomState(seed=17)
    inp = to_triton(numpy_random((M, N), dtype_str, rs), device="cuda", dst_type=dtype_str)
    out = to_triton(numpy_random((M, N), dtype_str, rs), device="cuda", dst_type=dtype_str)

    grid_m = M // M_BLOCK
    grid_n = N // N_BLOCK

    if descriptor == "host":
        out_desc = TensorDescriptor.from_tensor(out, [M_BLOCK, N_BLOCK])
    else:

        # Device-side descriptors need scratch memory: the asserts pin the
        # allocator contract of 128 bytes per program per CTA, 128-byte
        # aligned, allocated on the default stream.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            assert size == 128 * (grid_m * grid_n) * num_ctas
            assert align == 128
            assert stream == 0
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(alloc_fn)
        out_desc = None

    # Unsupported combos must raise: a CompilationError when even the native
    # path rejects the dtype, otherwise a RuntimeError from the fallback.
    dtype = getattr(tl, dtype_str)
    native_supported = dtype in NATIVE_SUPPORTED_REDUCE_DTYPES[kind]
    fallback_supported = dtype in FALLBACK_SUPPORTED_REDUCE_DTYPES[kind]
    supported = native_supported if is_native else fallback_supported
    if not supported:
        exc_type = CompilationError if not native_supported else RuntimeError
        with pytest.raises(exc_type):
            kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
        return

    # Compute the expected result before launching, since the kernel reduces
    # into `out` in place.
    expect = REDUCE_OP[kind](inp, out)
    kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
    torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)
0 commit comments