
Commit 24986bd

Sync test_core.py from upstream (#4741)
Pass rate: 99.04%->98.46%
2 parents: b361546 + 60bba60


python/test/unit/language/test_core.py

Lines changed: 71 additions & 22 deletions
@@ -1580,6 +1580,7 @@ def kernel(X, Y, Z):
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
     itertools.chain.from_iterable([[
+        ('add', 'bfloat16', mode, sem),
         ('add', 'float16', mode, sem),
         ('add', 'uint32', mode, sem),
         ('add', 'int32', mode, sem),
@@ -1609,6 +1610,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
         pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
     if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
         pytest.xfail("uint cannot be negative")
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     n_programs = 5

@@ -1623,12 +1626,14 @@ def kernel(X, Z):
     sem_arg = sem if sem is None else f'"{sem}"'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
-    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

     # triton result
     rs = RandomState(17)
+    dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None
+    dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str
     x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
     if mode == 'all_neg':
         x = -np.abs(x)
@@ -1640,12 +1645,17 @@ def kernel(X, Z):
     if mode == 'max_pos':
         idx = rs.randint(n_programs, size=(1, )).item()
         x[idx] = np.max(np.abs(x)) + 1
-    x_tri = to_triton(x, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dst_type)

-    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type)
     h = kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    if dst_type == 'bfloat16':
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+    else:
+        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
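The bfloat16 branch above builds its reference in float32 and then masks off the low 16 bits before comparing. A bfloat16 value is, bit for bit, the top 16 bits of an IEEE-754 float32 (sign, 8 exponent bits, 7 explicit mantissa bits), so the mask truncates the reference to the most precision a bf16 result could carry. A standalone sketch of the same trick (helper name and sample values are illustrative):

import numpy as np

def trunc_to_bf16_precision(a):
    # keep sign + exponent + top 7 mantissa bits of each float32, i.e. the
    # bits a bfloat16 value can represent (truncation toward zero)
    return (np.asarray(a, dtype=np.float32).view('uint32') & np.uint32(0xffff0000)).view('float32')

print(trunc_to_bf16_precision([3.14159, 1234.567]))  # 3.14159 -> 3.140625, 1234.567 -> 1232.0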
@@ -1656,6 +1666,12 @@ def kernel(X, Z):
     if not is_cuda():
         return

+    # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
+    # atom.cas add loop on Ampere and prior
+    if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
+        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
+        return
+
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]


@@ -1680,10 +1696,12 @@ def kernel(X):
     for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
     for axis in [0, 1]
     for num_ctas in num_ctas_list
-    for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']
+    for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']
     for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
     shape0, shape1 = shape
     # triton kernel

@@ -1694,14 +1712,14 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])

-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             # sum can have bad numerics when accumulating in float16.
             # if we're dealing with float16, do the sum in float32.
             x = x.to(tl.float32)

         z = tl.sum(x, axis=AXIS)

-        if DTYPE == tl.float16:
+        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
             z = z.to(DTYPE)

         if AXIS == 1:
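The kernel widens float16 and bfloat16 inputs to float32 before tl.sum because both formats run out of mantissa bits quickly once partial sums grow. A quick illustration of the rounding that the widening avoids (exact values, not random data):

import torch

# float16 has 10 mantissa bits: at 2048 the spacing of representable values is 2,
# so adding 1 is lost to round-to-nearest-even.
print(torch.tensor(2048.0, dtype=torch.float16) + 1.0)   # tensor(2048., dtype=torch.float16)

# bfloat16 keeps only 7 explicit mantissa bits, so the same loss already happens at 256.
print(torch.tensor(256.0, dtype=torch.bfloat16) + 1.0)   # tensor(256., dtype=torch.bfloat16)

# accumulating in float32 preserves the increment.
print(torch.tensor(2048.0, dtype=torch.float32) + 1.0)   # tensor(2049.)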
@@ -1717,7 +1735,7 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
-    old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
+    old = np.zeros(z_shape, dtype=z.dtype)
     # reference results
     if x.dtype == np.float16:
         # do the sum in float32 to reduce numerical variation
@@ -1726,17 +1744,31 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
-    x_tri = to_triton(x, device=device)
-    z_tri = to_triton(z, device=device)
-    old_tri = to_triton(old, device=device)
+    x_tri = to_triton(x, device=device, dst_type=dtype_x_str)
+    z_tri = to_triton(z, device=device, dst_type=dtype_x_str)
+    old_tri = to_triton(old, device=device, dst_type=dtype_x_str)

     def torch_to_triton_dtype(t):
+        if t == torch.bfloat16:
+            return tl.bfloat16
         if t == torch.float16:
             return tl.float16
         return None

     kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val,
                   num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        # trunc mantissa for a fair comparison of accuracy
+        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
+        # mantissa trunc is not enough, bump up the relative tolerance as well
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5)
+        # check return vals, but use assert_allclose for bf16
+        if check_return_val:
+            np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5)
+        return
+
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     if check_return_val:
         np.testing.assert_equal(old_ref, to_numpy(old_tri))
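For the bfloat16 path the comparison is np.testing.assert_allclose with rtol=0.5 (and the default atol=0), i.e. it only requires |actual - desired| <= atol + rtol * |desired| elementwise. A tiny worked example of that bound (values are illustrative):

import numpy as np

desired = np.float32(100.0)
actual = np.float32(140.0)   # 40% relative error
np.testing.assert_allclose(actual, desired, rtol=0.5)   # passes: 40 <= 0 + 0.5 * 100
# the strict float path above (rtol=1e-4) would reject the same pair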
@@ -1746,8 +1778,11 @@ def torch_to_triton_dtype(t):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
     for size in [2, 4, 8, 32, 64, 128]
     for num_ctas in num_ctas_list
-    for dtype_x_str in ['float16', 'float32']])
+    for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
+    check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1757,8 +1792,9 @@ def kernel(X, val, NUM: tl.constexpr):
         tl.atomic_add(X + offset // 2, val)

     shape = (size // 2, size)
-    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
-    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    x = torch.zeros(shape, dtype=dtype, device=device)
+    val = torch.randn((size**2), dtype=dtype, device=device)
     kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
     ref = val[0::2] + val[1::2]
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
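Because the kernel stores through offset // 2, two consecutive lanes hit the same element, so every output slot accumulates exactly two contributions; that is what the reference ref = val[0::2] + val[1::2] computes. The same reduction spelled out with an explicit loop, as a sketch with an arbitrary size and independent of the test harness:

import torch

size = 8
val = torch.randn(size**2)

expected = torch.zeros(size**2 // 2)
for i in range(size**2):
    expected[i // 2] += val[i]        # two neighbouring indices share one slot

ref = val[0::2] + val[1::2]           # vectorized form used by the test
torch.testing.assert_close(ref, expected)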
@@ -1768,9 +1804,11 @@ def kernel(X, val, NUM: tl.constexpr):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
     for size in [2, 4, 8, 32, 64, 128]
     for num_ctas in num_ctas_list
-    for dtype_x_str in ['float16', 'float32']])
+    for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_shift_1(size, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")

     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1801,12 +1839,15 @@ def kernel(X, val, NUM: tl.constexpr):
     for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
     for mask_step in range(1, 5)
     for num_ctas in num_ctas_list
-    for dtype_x_str in ['float16', 'float32']])
+    for dtype_x_str in ['bfloat16', 'float16', 'float32']])
 def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
         pytest.xfail("not supported in the interpreter")

+    if is_xpu() and dtype_x_str == 'bfloat16':
+        pytest.skip("bfloat16 not yet supported for xpu")
+
     @triton.jit
     def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.constexpr):
         xoffset = tl.program_id(0) * XBLOCK
@@ -1829,8 +1870,9 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
     if idx_order == 'random':
         idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)

-    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
-    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dtype = getattr(torch, dtype_x_str)
+    val = torch.randn((shape0, shape1), dtype=dtype, device=device)
+    dst = torch.randn((shape0, shape1), dtype=dtype, device=device)

     dst_ref = dst.clone()

@@ -1842,6 +1884,11 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
         cnt += 1

     kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
+
+    if dtype_x_str == 'bfloat16':
+        torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1)
+        return
+
     np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)


@@ -3248,6 +3295,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
+    if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
+        pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")

     if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
         src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8
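The new guard follows from how a LinearLayout encodes its lanes: as the check's 1 << len(src_layout.lane) implies, each entry in the lane list is one basis vector of a lane-id bit, so the layout addresses exactly that many threads per warp and only matches hardware with the same warp width. A hypothetical illustration (the basis vectors below are made up, not taken from this file):

lane_bases = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]]  # 5 lane-id bits
threads_per_warp_assumed = 1 << len(lane_bases)          # 32
# On a target with THREADS_PER_WARP == 32 the case runs;
# on a 64-wide wavefront the new check xfails it instead.
assert threads_per_warp_assumed == 32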
@@ -7646,11 +7695,11 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     pat += str(axis)
     pat += r" : i32, efficient_layout} : \(tensor\<"
     pat += src_spec
-    pat += r", (#[a-z]+[0-9]*)\>, tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>, tensor\<"
     pat += indices_spec
-    pat += r", (#[a-z]+[0-9]*)\>\) -> tensor\<"
+    pat += r", (#[a-z]+[0-9]+)\>\) -> tensor\<"
     pat += output_spec
-    pat += r", (#[a-z]+[0-9]*)\>"
+    pat += r", (#[a-z]+[0-9]+)\>"

     repl = r"""
 %src = ttg.convert_layout \2 : tensor<""" + src_spec + r""", \4> -> tensor<""" + src_spec + r""", #src_layout>
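The regex tightening from [0-9]* to [0-9]+ makes the captured layout attribute require a numeric suffix (e.g. #blocked1), where the old pattern also accepted a bare #blocked. A quick check of the difference:

import re

old_pat = r"#[a-z]+[0-9]*"
new_pat = r"#[a-z]+[0-9]+"

print(bool(re.fullmatch(old_pat, "#blocked")))   # True: digits were optional
print(bool(re.fullmatch(new_pat, "#blocked")))   # False: at least one digit required
print(bool(re.fullmatch(new_pat, "#blocked1")))  # True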
