Commit 13dae07

Revert partial "[BACKEND] BF16 atomic_add support (#6519)"
This partially reverts commit 236f6b5.
1 parent 629abab · commit 13dae07


python/test/unit/language/test_core.py

Lines changed: 20 additions & 56 deletions
@@ -1564,7 +1564,6 @@ def kernel(X, Y, Z):
 @pytest.mark.parametrize(
     "op, dtype_x_str, mode, sem",
     itertools.chain.from_iterable([[
-        ('add', 'bfloat16', mode, sem),
         ('add', 'float16', mode, sem),
         ('add', 'uint32', mode, sem),
         ('add', 'int32', mode, sem),
@@ -1590,8 +1589,8 @@ def kernel(X, Y, Z):
 def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
-        if dtype_x_str == 'float16' or dtype_x_str == 'bfloat16':
-            pytest.xfail("Only test atomic bfloat16/float16 ops on GPU")
+        if dtype_x_str == 'float16':
+            pytest.xfail("Only test atomic float16 ops on GPU")
 
     n_programs = 5
 
@@ -1606,14 +1605,12 @@ def kernel(X, Z):
     sem_arg = sem if sem is None else f'"{sem}"'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
     numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
-    max_neutral = float('-inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).min
-    min_neutral = float('inf') if dtype_x_str in float_dtypes_with_bfloat16 else np.iinfo(getattr(np, dtype_x_str)).max
+    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
+    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
     neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
 
     # triton result
     rs = RandomState(17)
-    dst_type = 'bfloat16' if (dtype_x_str == 'bfloat16') else None
-    dtype_x_str = 'float32' if (dtype_x_str == 'bfloat16') else dtype_x_str
     x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
     if mode == 'all_neg':
         x = -np.abs(x)
@@ -1625,17 +1622,12 @@ def kernel(X, Z):
     if mode == 'max_pos':
         idx = rs.randint(n_programs, size=(1, )).item()
         x[idx] = np.max(np.abs(x)) + 1
-    x_tri = to_triton(x, device=device, dst_type=dst_type)
+    x_tri = to_triton(x, device=device)
 
-    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device, dst_type=dst_type)
+    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
     h = kernel[(n_programs, )](x_tri, z_tri)
     # torch result
-    if dst_type == 'bfloat16':
-        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
-        # trunc mantissa for a fair comparison of accuracy
-        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-    else:
-        z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
+    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
     # compare
     exact = op not in ['add']
     if exact:
@@ -1646,12 +1638,6 @@ def kernel(X, Z):
     if not is_cuda():
         return
 
-    # atom.add.bf16 is unsupported prior to Hopper so instead we generate an
-    # atom.cas add loop on Ampere and prior
-    if dst_type == 'bfloat16' and torch.cuda.get_device_capability()[0] < 9:
-        assert f"atom.{sem_str}.global.cas" in h.asm["ptx"]
-        return
-
     assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"]
 
 
@@ -1676,7 +1662,7 @@ def kernel(X):
                          for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                          for axis in [0, 1]
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64']
+                         for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']
                          for check_return_val in ([True, False] if is_hip() else [True])])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, check_return_val, device):
     check_type_supported(dtype_x_str, device)
@@ -1690,14 +1676,14 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
 
-        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
+        if DTYPE == tl.float16:
            # sum can have bad numerics when accumulating in float16.
            # if we're dealing with float16, do the sum in float32.
            x = x.to(tl.float32)
 
         z = tl.sum(x, axis=AXIS)
 
-        if DTYPE == tl.float16 or DTYPE == tl.bfloat16:
+        if DTYPE == tl.float16:
            z = z.to(DTYPE)
 
         if AXIS == 1:
@@ -1713,7 +1699,7 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
     z_shape = (shape0, ) if axis == 1 else (shape1, )
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
-    old = np.zeros(z_shape, dtype=z.dtype)
+    old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
     if x.dtype == np.float16:
         # do the sum in float32 to reduce numerical variation
@@ -1722,31 +1708,17 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
         z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
-    x_tri = to_triton(x, device=device, dst_type=dtype_x_str)
-    z_tri = to_triton(z, device=device, dst_type=dtype_x_str)
-    old_tri = to_triton(old, device=device, dst_type=dtype_x_str)
+    x_tri = to_triton(x, device=device)
+    z_tri = to_triton(z, device=device)
+    old_tri = to_triton(old, device=device)
 
     def torch_to_triton_dtype(t):
-        if t == torch.bfloat16:
-            return tl.bfloat16
         if t == torch.float16:
             return tl.float16
         return None
 
     kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), check_return_val,
                   num_ctas=num_ctas)
-
-    if dtype_x_str == 'bfloat16':
-        # trunc mantissa for a fair comparison of accuracy
-        z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-        old_ref = (old_ref.view('uint32') & np.uint32(0xffff0000)).view('float32')
-        # mantissa trunc is not enough, bump up the relative tolerance as well
-        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.5)
-        # check return vals, but use assert_allclose for bf16
-        if check_return_val:
-            np.testing.assert_allclose(old_ref, to_numpy(old_tri), rtol=0.5)
-        return
-
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     if check_return_val:
         np.testing.assert_equal(old_ref, to_numpy(old_tri))
@@ -1756,9 +1728,8 @@ def torch_to_triton_dtype(t):
 @pytest.mark.parametrize("size, num_ctas, dtype_x_str", [(size, num_ctas, dtype_x_str)
                          for size in [2, 4, 8, 32, 64, 128]
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+                         for dtype_x_str in ['float16', 'float32']])
 def test_tensor_atomic_add_non_exclusive_offset(size, num_ctas, dtype_x_str, device):
-    check_type_supported(dtype_x_str, device)
 
     @triton.jit
     def kernel(X, val, NUM: tl.constexpr):
@@ -1768,9 +1739,8 @@ def kernel(X, val, NUM: tl.constexpr):
         tl.atomic_add(X + offset // 2, val)
 
     shape = (size // 2, size)
-    dtype = getattr(torch, dtype_x_str)
-    x = torch.zeros(shape, dtype=dtype, device=device)
-    val = torch.randn((size**2), dtype=dtype, device=device)
+    x = torch.zeros(shape, dtype=getattr(torch, dtype_x_str), device=device)
+    val = torch.randn((size**2), dtype=getattr(torch, dtype_x_str), device=device)
     kernel[(1, )](x, val, size, num_warps=1, num_ctas=num_ctas)
     ref = val[0::2] + val[1::2]
     torch.testing.assert_close(ref, x.reshape(math.prod(shape)))
@@ -1783,7 +1753,7 @@ def kernel(X, val, NUM: tl.constexpr):
                          for idx_order in ['increase', 'decrease', 'random_no_duplication', 'random']
                          for mask_step in range(1, 5)
                          for num_ctas in num_ctas_list
-                         for dtype_x_str in ['bfloat16', 'float16', 'float32']])
+                         for dtype_x_str in ['float16', 'float32']])
 def test_tensor_atomic_add_access_patterns(shape, idx_order, mask_step, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
     if is_interpreter():
@@ -1811,9 +1781,8 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
     if idx_order == 'random':
         idx = torch.randint(0, shape1, size=(shape0, shape1), device=device)
 
-    dtype = getattr(torch, dtype_x_str)
-    val = torch.randn((shape0, shape1), dtype=dtype, device=device)
-    dst = torch.randn((shape0, shape1), dtype=dtype, device=device)
+    val = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
+    dst = torch.randn((shape0, shape1), dtype=getattr(torch, dtype_x_str), device=device)
 
     dst_ref = dst.clone()
 
@@ -1825,11 +1794,6 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
         cnt += 1
 
     kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
-
-    if dtype_x_str == 'bfloat16':
-        torch.testing.assert_close(dst_ref, dst, rtol=0.1, atol=0.1)
-        return
-
     np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)
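Note on the reverted comparison logic: the deleted branches did not compare bf16 results directly. They first truncated the float32 reference to bfloat16 precision ("trunc mantissa for a fair comparison of accuracy"), exploiting the fact that a bfloat16 value is the upper 16 bits of a float32. Below is a minimal standalone sketch of that masking trick; the helper name is illustrative and not part of the test file.

import numpy as np

def truncate_to_bf16(a):
    # Zero the low 16 bits of each float32 element, i.e. keep only the bits a
    # bfloat16 can represent; this mirrors the masking in the reverted test code.
    assert a.dtype == np.float32
    return (a.view('uint32') & np.uint32(0xffff0000)).view('float32')

x = np.array([3.14159265], dtype=np.float32)
print(truncate_to_bf16(x))  # [3.140625] -- exactly representable in bfloat16

Masking truncates toward zero rather than rounding to nearest, which is why the reverted comparison also loosened the tolerance (rtol=0.5) instead of demanding exact equality.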
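The other deleted check in test_atomic_rmw encoded a hardware constraint: per the removed comment, atom.add.bf16 is unsupported prior to Hopper, so on Ampere and earlier the backend emits an atom.cas add loop instead, and the test asserted that the generated PTX contained the CAS form. A hedged sketch of that capability gate follows; the helper name and structure are mine, while the PTX substrings and the h.asm["ptx"] lookup come from the reverted lines.

import torch

def expected_atomic_add_ptx(sem_str, is_bf16):
    # Per the reverted test: compute capability < 9 (pre-Hopper) plus bf16 means
    # the atomic add is lowered to a compare-and-swap loop; otherwise the usual
    # atom.global.gpu instruction (checked by the surviving assertion) is expected.
    major, _minor = torch.cuda.get_device_capability()
    if is_bf16 and major < 9:
        return f"atom.{sem_str}.global.cas"
    return f"atom.global.gpu.{sem_str}"

# Usage mirroring the test:
#   assert expected_atomic_add_ptx(sem_str, dst_type == 'bfloat16') in h.asm["ptx"]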