Skip to content

Commit a5e485f

Browse files
authored
[FRONTEND] Fix atomic min/max for float with negative zero (#6431)
Fixes #6376. The software emulation of atomic min/max uses `x >= 0` to test the sign bit, which breaks down when `x` is `-0.0`: it compares equal to zero yet has the sign bit set. I fix this by looking at the bit representation of the float and extracting the sign bit directly. I also fix `not_` raising an error in the interpreter from `get_all_ones_value`, since `np.bool` doesn't have "int" in its name.
1 parent 3eb8501 commit a5e485f

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

python/test/unit/language/test_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,6 +1851,26 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
18511851
out = kernel_r[(2, )](data, BLOCK_SIZE=block_size, num_ctas=1, launch_cooperative_grid=False)
18521852

18531853

1854+
@pytest.mark.interpreter
1855+
def test_atomic_min_max_neg_zero(device):
1856+
1857+
@triton.jit
1858+
def kernel(inp, out_max, out_min):
1859+
idx = tl.program_id(0)
1860+
x = tl.load(inp + idx)
1861+
tl.atomic_max(out_max + idx, x)
1862+
tl.atomic_min(out_min + idx, x)
1863+
1864+
N_PROG = 1
1865+
dtype = torch.float32
1866+
out_min = torch.full([N_PROG], torch.finfo(torch.float32).max, device=device, dtype=dtype)
1867+
out_max = torch.full([N_PROG], torch.finfo(torch.float32).min, device=device, dtype=dtype)
1868+
inp = torch.full([N_PROG], -0.0, device=device, dtype=dtype)
1869+
kernel[(N_PROG, )](inp, out_max, out_min)
1870+
torch.testing.assert_close(out_min, inp, atol=0, rtol=0)
1871+
torch.testing.assert_close(out_max, inp, atol=0, rtol=0)
1872+
1873+
18541874
# ---------------
18551875
# test cast
18561876
# ---------------

python/triton/language/semantic.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,14 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
13881388
return ptr, val, mask
13891389

13901390

1391+
def _signbit(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Return an i1 tensor that is True where ``x`` has its IEEE-754 sign bit set.

    Unlike an ``x < 0`` comparison, this classifies ``-0.0`` as negative
    (``-0.0 >= 0`` is true even though its sign bit is set), which the
    software emulation of float atomic min/max relies on.
    """
    bitwidth = x.dtype.primitive_bitwidth
    # Reinterpret the float's bits as an unsigned integer of the same width,
    # then logically shift the sign bit (the MSB) down into bit 0.
    idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False)
    ix = x.to(idtype, bitcast=True, _builder=builder)
    signbit = lshr(ix, bitwidth - 1, builder)
    return signbit.to(tl.int1, _builder=builder)
1397+
1398+
13911399
def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
13921400
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
13931401
sem = _str_to_sem(sem)
@@ -1407,16 +1415,14 @@ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope:
14071415
if sca_ty not in {tl.float32, tl.float64}:
14081416
raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
14091417

1410-
zero = full([], 0.0, sca_ty, builder)
1411-
14121418
i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
14131419
i_val = bitcast(val, i_type, builder)
14141420
i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
14151421
ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
14161422
ui_val = bitcast(val, ui_type, builder)
14171423
ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
1418-
pos = greater_equal(val, zero, builder)
1419-
neg = less_than(val, zero, builder)
1424+
neg = _signbit(val, builder)
1425+
pos = not_(neg, builder)
14201426
pos_ret = tl.tensor(
14211427
builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
14221428
and_(mask, pos, builder).handle, sem, scope), i_val.type)
@@ -1446,16 +1452,14 @@ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope:
14461452
if sca_ty not in {tl.float32, tl.float64}:
14471453
raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
14481454

1449-
zero = full([], 0.0, sca_ty, builder)
1450-
14511455
i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
14521456
i_val = bitcast(val, i_type, builder)
14531457
i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
14541458
ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
14551459
ui_val = bitcast(val, ui_type, builder)
14561460
ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
1457-
pos = greater_equal(val, zero, builder)
1458-
neg = less_than(val, zero, builder)
1461+
neg = _signbit(val, builder)
1462+
pos = not_(neg, builder)
14591463
pos_ret = tl.tensor(
14601464
builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
14611465
and_(mask, pos, builder).handle, sem, scope), i_val.type)

python/triton/runtime/interpreter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,8 @@ def get_all_ones_value(self, type):
756756
np_type = _get_np_dtype(type)
757757
if "int" in np_type.name:
758758
return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
759+
elif np_type == np.bool_:
760+
return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
759761
else:
760762
raise TypeError(f"unsupported type {type}")
761763

0 commit comments

Comments
 (0)