Commit 03155cd

davidberard98 authored and anmyachev committed
[TEST] float16 test for test_tensor_atomic_rmw (#4981)
This adds float16 to the list of dtypes tested in test_tensor_atomic_rmw. Note that the numerics were previously bad for this test when run in float16; this PR "fixes" the numerics by internally doing the sum in float32 (upcast, sum, downcast). Since the purpose is to test the atomic_rmw, and the numerical issues of doing sums in low-precision dtypes are generally known, I think this strategy should be fine for this test.
1 parent 557b2cd commit 03155cd
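To see why the upcast matters, here is a small illustrative NumPy sketch (not part of the commit) comparing a sum accumulated directly in float16 against the upcast-sum-downcast pattern the test now uses; the exact error magnitudes will vary with the random data:

import numpy as np

rs = np.random.RandomState(0)
x = rs.uniform(-1, 1, size=(64, 64)).astype(np.float16)

# Accumulate directly in float16: every partial sum is rounded to float16.
sum_fp16 = np.sum(x, axis=0, dtype=np.float16)

# Upcast, sum in float32, downcast once at the end (the pattern used by the test).
sum_fp32 = np.sum(x.astype(np.float32), axis=0).astype(np.float16)

# float64 reference to measure how far each result drifts.
ref = np.sum(x.astype(np.float64), axis=0)

print("float16 accumulation error:", np.max(np.abs(sum_fp16.astype(np.float64) - ref)))
print("float32 accumulation error:", np.max(np.abs(sum_fp32.astype(np.float64) - ref)))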

File tree: 1 file changed (+26 −4)

python/test/unit/language/test_core.py

Lines changed: 26 additions & 4 deletions
@@ -1490,18 +1490,30 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_interpreter() and dtype_x_str == 'float16':
+        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+
+        if DTYPE == tl.float16:
+            # sum can have bad numerics when accumulating in float16.
+            # if we're dealing with float16, do the sum in float32.
+            x = x.to(tl.float32)
+
         z = tl.sum(x, axis=AXIS)
+
+        if DTYPE == tl.float16:
+            z = z.to(DTYPE)
+
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    if x.dtype == np.float16:
+        # do the sum in float32 to reduce numerical variation
+        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
+    else:
+        z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
+
+    def torch_to_triton_dtype(t):
+        if t == torch.float16:
+            return tl.float16
+        return None
+
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
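To exercise just the new float16 cases locally, an invocation along these lines should work (assuming a Triton checkout with the test dependencies installed; the -k expression simply matches the parametrized test ids):

pytest python/test/unit/language/test_core.py -k "test_tensor_atomic_rmw and float16"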

0 commit comments
