|
21 | 21 | from triton._internal_testing import ( |
22 | 22 | integral_dtypes, |
23 | 23 | int_dtypes, |
| 24 | + str_to_triton_dtype, |
24 | 25 | uint_dtypes, |
25 | 26 | float_dtypes, |
26 | 27 | float_dtypes_with_bfloat16, |
@@ -1641,7 +1642,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): |
1641 | 1642 | ('float32', 'bfloat16', False, 1024), |
1642 | 1643 | ('bfloat16', 'float32', False, 1024), |
1643 | 1644 | ('float32', 'int32', True, 1024), |
1644 | | - ('float32', 'int1', False, 1024), |
| 1645 | + ('float32', 'bool', False, 1024), |
1645 | 1646 | ('int8', 'bfloat16', False, 1024), |
1646 | 1647 | ] + [(f'uint{x}', f'int{x}', True, 1024) |
1647 | 1648 | for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) |
@@ -1687,35 +1688,40 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): |
1687 | 1688 | # triton kernel |
1688 | 1689 |
|
1689 | 1690 | @triton.jit |
1690 | | - def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): |
| 1691 | + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): |
1691 | 1692 | x_ptr = X + tl.arange(0, SIZE) |
1692 | 1693 | z_ptr = Z + tl.arange(0, SIZE) |
1693 | 1694 | x = tl.load(x_ptr) |
1694 | 1695 |
|
1695 | 1696 | # Depending on the value of ARG_HASH (a "random" number determined by |
1696 | 1697 | # the test parameters), spell the cast one of four different ways. |
1697 | | - if ARG_HASH % 3 == 0: |
| 1698 | + if ARG_HASH % 4 == 0: |
1698 | 1699 | z = x.to(Z.dtype.element_ty, bitcast=BITCAST) |
1699 | | - elif ARG_HASH % 3 == 1: |
| 1700 | + elif ARG_HASH % 4 == 1: |
1700 | 1701 | z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) |
1701 | | - else: |
| 1702 | + elif ARG_HASH % 4 == 2: |
1702 | 1703 | z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) |
| 1704 | + else: |
| 1705 | + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) |
1703 | 1706 |
|
1704 | 1707 | tl.store(z_ptr, z) |
1705 | 1708 |
|
1706 | 1709 | # "Random" number used inside the kernel to determine how we spell the cast. |
1707 | 1710 | # This way we don't have to increase the number of tests. |
1708 | 1711 | arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) |
1709 | 1712 |
|
1710 | | - dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' |
| 1713 | + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' |
1711 | 1714 | # triton result |
1712 | 1715 | if dtype_z.startswith('bfloat'): |
1713 | 1716 | z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) |
1714 | 1717 | elif dtype_z.startswith('float8'): |
1715 | 1718 | z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) |
1716 | 1719 | else: |
1717 | 1720 | z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) |
1718 | | - kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) |
| 1721 | + |
| 1722 | + dtype_z_tri = str_to_triton_dtype(dtype_z) |
| 1723 | + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, |
| 1724 | + num_ctas=num_ctas) |
1719 | 1725 | # torch result |
1720 | 1726 | if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( |
1721 | 1727 | 'float8') or dtype_x.startswith('float8'): |
|
0 commit comments