Skip to content

Commit 1d004a9

Browse files
authored
Support sorting single-element tensors (#8040)
As per #6769, reshape([]) creates a scalar rather than a tensor. This breaks the sorting algorithm, so special case this situation.
1 parent e2ddc50 commit 1d004a9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/test/unit/language/test_standard.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_maximum_minium(dtype, op, device):
2626

2727

2828
@pytest.mark.interpreter
29-
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
29+
@pytest.mark.parametrize("M, N", [[1, 1], [1, 512], [8, 64], [256, 16], [512, 8]])
3030
@pytest.mark.parametrize("k", [None, 8])
3131
@pytest.mark.parametrize("descending", [False, True])
3232
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
@@ -40,7 +40,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
4040
offs_z_n = offs_x_n if k is None else tl.arange(0, k)
4141
offs_x = offs_m[:, None] * stride_xm + offs_x_n[None, :]
4242
x = tl.load(X + offs_x)
43-
if k is None:
43+
if k is None or x.numel < k:
4444
z = tl.sort(x, descending=descending)
4545
else:
4646
z = tl.topk(x, k)
@@ -51,7 +51,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
5151
x = numpy_random((M, N), dtype_str=dtype_str)
5252
x = torch.from_numpy(x).to(device)
5353
z = torch.empty(z_shape, dtype=x.dtype, device=x.device)
54-
if k is None:
54+
if k is None or x.numel() < k:
5555
y = torch.sort(x, descending=descending)[0]
5656
else:
5757
y = torch.topk(x, k=k).values

python/triton/language/standard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descendin
441441
n_dims: core.constexpr = _log2(x.numel)
442442

443443
# reshape to hypercube:
444-
h = core.reshape(x, [2] * n_dims)
444+
h = core.reshape(x, [2] * n_dims if n_dims else [1])
445445

446446
# run first log_k bitonic sort iterations:
447447
for i in core.static_range(1, log_k + 1):

0 commit comments

Comments
 (0)