Commit a52c88a

[STANDARD] Fix inf handling in tl.flip (#5447)
Fixes #5439. Currently we end up computing `0 * inf = nan` inside `tl.flip`; the fix is to bitcast to an integer type first, where `x * 0 == 0` always holds, and bitcast back to the original dtype at the end.
1 parent: e57b468
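
To make the failure mode concrete, here is a minimal NumPy sketch (not part of the commit; the values are illustrative) of why a multiply-by-0/1-and-sum selection breaks on `inf`, and why doing the same arithmetic on the integer bit pattern does not:

```python
import numpy as np

# tl.flip selects elements by multiplying with a 0/1 matrix and summing.
# With floats, that breaks as soon as an element is inf:
x = np.array([1.0, np.inf], dtype=np.float32)
keep_first = np.array([1.0, 0.0], dtype=np.float32)
print(x * keep_first)                    # [ 1. nan] -- inf * 0 is nan, which poisons the sum

# Bitcasting to a same-width integer type makes the same selection exact,
# because int * 0 == 0 and the kept element's bit pattern is untouched:
xi = x.view(np.int32)                    # reinterpret bits, no value conversion
sel = xi * keep_first.astype(np.int32)   # [1065353216, 0]
print(sel.view(np.float32))              # [1. 0.] -- no nan, bits recovered exactly
```

The patch below applies exactly this idea inside `tl.flip`: bitcast to a same-width integer type, do the selection arithmetic there, then bitcast back.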

File tree: 3 files changed, +32 -5 lines

python/test/unit/language/test_standard.py

Lines changed: 23 additions & 0 deletions
@@ -75,6 +75,29 @@ def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr):
     assert (y == z).all(), (y, z)


+@pytest.mark.interpreter
+def test_flip_inf(device):
+    # Reproducer for https://github.com/triton-lang/triton/issues/5439
+
+    @triton.jit
+    def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr):
+        pid = tl.program_id(0)
+        x = tl.load(x_ptr + pid * N + tl.arange(0, N))
+        shape: tl.constexpr = (N // 2, 2)
+        y = x.reshape(shape)
+        y = tl.flip(y, dim=1).reshape(x.shape)
+        tl.store(out_ptr + pid * N + tl.arange(0, N), y)
+
+    x = torch.arange(0, 16, device=device).unsqueeze(0).float()
+    x[:, -1] = float('inf')
+
+    expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16)
+    actual = torch.empty_like(x)
+    triton_flip_kernel[(x.shape[0], )](actual, x, x.shape[1])
+
+    torch.testing.assert_close(expect, actual)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]])
 def test_swizzle2d(size_i, size_j, size_g, device):
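
For reference, the `expect` tensor the new test compares against swaps each adjacent pair while keeping the `inf` value intact; a quick standalone torch sketch of that reference computation (not part of the commit):

```python
import torch

x = torch.arange(0, 16).float().unsqueeze(0)
x[:, -1] = float('inf')
expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16)
print(expect)
# expected values: [[1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10., 13., 12., inf, 14.]]
```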

python/triton/language/standard.py

Lines changed: 6 additions & 4 deletions
@@ -412,11 +412,13 @@ def flip(x, dim=None):
     """
     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
     core.static_assert(_is_power_of_two(x.numel))
-    # # reshape the tensor to have all dimensions be 2.
-    # # TODO: We shouldn't have to change the dimensions not sorted.
+    # reshape the tensor to have all dimensions be 2.
+    # TODO: We shouldn't have to change the dimensions not sorted.
     steps: core.constexpr = _log2(x.numel)
     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
-    y = core.reshape(x, [2] * steps)
+
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
     y = core.expand_dims(y, start)
     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
     for i in core.static_range(start, steps):
@@ -425,7 +427,7 @@ def flip(x, dim=None):
             if j != i and j != i + 1:
                 flip2 = core.expand_dims(flip2, j)
         y = sum(y * flip2, i + 1, keep_dims=True)
-    x = core.reshape(y, x.shape)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x

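To see what the patched `flip` effectively does for the common case of a trailing size-2 axis, here is a standalone NumPy sketch (the helper name and shapes are illustrative, not from the repo): the 0/1 permutation matrix is built the same way as `flip` in the diff above, but the multiply-and-sum runs on the `int32` bit pattern and the result is bitcast back to `float32`.

```python
import numpy as np

def flip_last2_via_bitcast(x):
    # Flip a trailing size-2 axis the way the patched tl.flip does:
    # build a 0/1 permutation matrix, multiply-and-sum on the integer
    # bit pattern, then bitcast back so inf/nan survive the zero terms.
    assert x.dtype == np.float32 and x.shape[-1] == 2
    perm = (np.arange(2)[:, None] == 1 - np.arange(2)).astype(np.int32)  # [[0, 1], [1, 0]]
    y = x.view(np.int32)                        # bitcast, not a value conversion
    flipped = (y[..., None, :] * perm).sum(-1)  # each output slot picks the swapped element
    return flipped.astype(np.int32).view(np.float32)

x = np.array([[1.0, np.inf], [2.0, 3.0]], dtype=np.float32)
print(flip_last2_via_bitcast(x))   # [[inf, 1.], [3., 2.]]
```

Because only zero or the original bit pattern ever reaches the sum, `inf` and `nan` payloads pass through unchanged.
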
python/triton/runtime/interpreter.py

Lines changed: 3 additions & 1 deletion
@@ -726,10 +726,12 @@ def check_tensor(self, input):
         self.check_axis(arg.shape, self.axis)

     def to_tensor(self, ret, dtype):
+        np_dtype = _get_np_dtype(dtype)
         if hasattr(ret, "shape") and ret.shape:
+            ret = ret.astype(np_dtype)
             ret_type = tl.block_type(dtype, list(ret.shape))
         else:
-            ret = np.array([ret]).astype(_get_np_dtype(dtype))
+            ret = np.array([ret], dtype=np_dtype)
             ret_type = dtype
         return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
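
The interpreter change makes `to_tensor` cast array results to the NumPy dtype that matches the declared Triton dtype; previously only the scalar path did this. A minimal sketch of the kind of mismatch this guards against, assuming NumPy's default integer promotion rules (the values are illustrative):

```python
import numpy as np

# A reduction over int32 data is widened to the platform default integer,
# so the array handed back can have a different element size than declared.
y = np.array([1065353216, 0], dtype=np.int32)     # bit patterns of [1.0, 0.0]
s = y.sum(axis=0, keepdims=True)
print(s.dtype)                                    # typically int64, not int32

# Re-casting to the declared dtype keeps a later bitcast meaningful:
print(s.astype(np.int32).view(np.float32))        # [1.]
```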
