
Commit 1918084

[TEST] float16 test for test_tensor_atomic_rmw (#4981)
This adds float16 to the list of dtypes tested in test_tensor_atomic_rmw. Note that the numerics for this test were previously bad when run in float16; this PR "fixes" the numerics by internally doing the sum in float32 (upcast, sum, downcast). Since the purpose is to test atomic_rmw, and the numerical issues of doing sums in low-precision dtypes are generally known, I think this strategy should be fine for this test.
1 parent 15c5e55 commit 1918084
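
As a standalone illustration (not part of the diff below) of the numerical issue the commit message describes, the following NumPy sketch compares accumulating a sum directly in float16 against the upcast-sum-downcast approach; the array shape, value range, and seed are arbitrary choices for the example.

    # Sketch: float16 accumulation loses precision; upcasting to float32,
    # summing, and downcasting stays much closer to a float64 reference.
    import numpy as np

    rs = np.random.RandomState(0)
    x = rs.uniform(-1, 1, size=(64, 64)).astype(np.float16)

    ref = x.astype(np.float64).sum(axis=0)                        # high-precision reference
    naive = x.sum(axis=0, dtype=np.float16)                       # accumulate in float16
    upcast = x.astype(np.float32).sum(axis=0).astype(np.float16)  # upcast, sum, downcast

    print("max |error|, float16 accumulation:", np.abs(naive - ref).max())
    print("max |error|, float32 accumulation:", np.abs(upcast - ref).max())
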

File tree

1 file changed: +26 −4 lines changed


python/test/unit/language/test_core.py

Lines changed: 26 additions & 4 deletions
@@ -1453,17 +1453,29 @@ def kernel(X):
     for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
     for axis in [0, 1]
     for num_ctas in num_ctas_list
-    for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+    for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
+    if is_interpreter() and dtype_x_str == 'float16':
+        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+
+        if DTYPE == tl.float16:
+            # sum can have bad numerics when accumulating in float16.
+            # if we're dealing with float16, do the sum in float32.
+            x = x.to(tl.float32)
+
         z = tl.sum(x, axis=AXIS)
+
+        if DTYPE == tl.float16:
+            z = z.to(DTYPE)
+
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1477,13 +1489,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    if x.dtype == np.float16:
+        # do the sum in float32 to reduce numerical variation
+        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
+    else:
+        z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
+
+    def torch_to_triton_dtype(t):
+        if t == torch.float16:
+            return tl.float16
+        return None
+
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
