Commit 7efc694
1 parent 9487527 commit 7efc694

3 files changed: +42 −65 lines changed

python/test/unit/language/test_standard.py
Lines changed: 13 additions & 24 deletions

```diff
@@ -27,35 +27,24 @@ def test_maximum_minium(dtype, op, device):
 
 @pytest.mark.interpreter
 @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
-@pytest.mark.parametrize("k", [None, 8])
 @pytest.mark.parametrize("descending", [False, True])
 @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
-def test_sort(M, N, k, descending, dtype_str, device):
+def test_sort(M, N, descending, dtype_str, device):
 
     @triton.jit
-    def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k: tl.constexpr,
-                    descending: tl.constexpr):
-        offs_m = tl.arange(0, M)
-        offs_x_n = tl.arange(0, N)
-        offs_z_n = offs_x_n if k is None else tl.arange(0, k)
-        offs_x = offs_m[:, None] * stride_xm + offs_x_n[None, :]
-        x = tl.load(X + offs_x)
-        if k is None:
-            z = tl.sort(x, descending=descending)
-        else:
-            z = tl.topk(x, k)
-        offs_z = offs_m[:, None] * stride_zm + offs_z_n[None, :]
-        tl.store(Z + offs_z, z)
-
-    z_shape = (M, N if k is None else k)
-    x = numpy_random((M, N), dtype_str=dtype_str)
+    def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr):
+        offx = tl.arange(0, M)
+        offy = tl.arange(0, N) * M
+        off2d = offx[None, :] + offy[:, None]
+        x = tl.load(X + off2d)
+        x = tl.sort(x, descending=descending)
+        tl.store(Z + off2d, x)
+
+    x = numpy_random((N, M), dtype_str=dtype_str)
     x = torch.from_numpy(x).to(device)
-    z = torch.empty(z_shape, dtype=x.dtype, device=x.device)
-    if k is None:
-        y = torch.sort(x, descending=descending)[0]
-    else:
-        y = torch.topk(x, k=k).values
-    sort_kernel[(1, )](x, x.stride(0), z, z.stride(0), M, N, k, descending, num_warps=8)
+    y = torch.sort(x, descending=descending)[0]
+    z = torch.empty_like(x)
+    sort_kernel[(1, )](x, z, N, M, descending, num_warps=8)
     assert (y == z).all(), (y, z)
```
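A side note on the rewritten kernel's addressing: it assumes a contiguous row-major (N, M) block, so strides no longer need to be passed in. A minimal NumPy sketch (not part of the commit; names are illustrative) of the same indexing:

```python
import numpy as np

N, M = 4, 8
offx = np.arange(M)        # offsets within one row
offy = np.arange(N) * M    # offset of each row start
off2d = offx[None, :] + offy[:, None]   # same construction as sort_kernel

x = np.arange(N * M)       # stand-in for the flat X buffer
# gathering with off2d is exactly a row-major (N, M) view of the buffer
assert (x[off2d] == x.reshape(N, M)).all()
```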

python/triton/language/__init__.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -19,7 +19,6 @@
     sort,
     sum,
     swizzle2d,
-    topk,
     xor_sum,
     zeros,
     zeros_like,
@@ -253,7 +252,6 @@
     "sum",
     "swizzle2d",
     "tensor",
-    "topk",
     "trans",
     "tuple",
     "uint16",
```

python/triton/language/standard.py
Lines changed: 29 additions & 39 deletions

```diff
@@ -9,7 +9,7 @@
 
 def _log2(i: core.constexpr):
     log2 = 0
-    n = core.constexpr(i).value
+    n = i.value
     while n > 1:
         n >>= 1
         log2 += 1
```
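`_log2` is a compile-time helper: repeated halving counts floor(log2(n)). A plain-Python equivalent (illustrative only; the real helper operates on core.constexpr values, and the diff now assumes `i` is already one):

```python
def log2_floor(n: int) -> int:
    # count how many times n halves before reaching 1
    log2 = 0
    while n > 1:
        n >>= 1
        log2 += 1
    return log2

assert log2_floor(8) == 3 and log2_floor(512) == 9
```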
```diff
@@ -338,19 +338,20 @@ def cumprod(input, axis=0, reverse=False):
 def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
     n_outer: core.constexpr = x.numel >> n_dims
     shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
-
-    # flip along middle dimension (the bitwise XORs will be optimised away):
+    y = core.reshape(x, shape)
+    # slice left/right with 'stride' 2**(n_dims - i - 1)
+    mask = core.arange(0, 2)[None, :, None]
+    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
+    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
+    left = core.reshape(left, x.shape)
+    right = core.reshape(right, x.shape)
+    # actual compare-and-swap
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    ix = core.reshape(x, shape).to(idtype, bitcast=True)
-    iy = ix ^ xor_sum(ix, 1, True)
-    y = core.reshape(iy.to(x.dtype, bitcast=True), x.shape)
-
-    # determines whether we are in the right (rather than left) position along the axis:
-    is_right = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
-
-    # conditional swap:
-    ret = core.where((x > y) != (flip ^ is_right), y, x)
-    return ret
+    ileft = left.to(idtype, bitcast=True)
+    iright = right.to(idtype, bitcast=True)
+    ix = x.to(idtype, bitcast=True)
+    ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
+    return ret.to(x.dtype, bitcast=True)
 
 
 @jit
```
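A rough NumPy model of what one `_compare_and_swap` step computes (a sketch under assumed semantics, with n_outer = 1; the Triton version above performs the swap branchlessly via integer bitcasts and XOR):

```python
import numpy as np

def compare_and_swap(x, flip, i, n_dims):
    # partner elements sit 2**(n_dims - i - 1) apart, matching the reshape
    # [2**i, 2, 2**(n_dims - i - 1)] used above (with n_outer = 1)
    y = x.reshape(2**i, 2, 2**(n_dims - i - 1))
    left, right = y[:, 0:1, :], y[:, 1:2, :]
    swap = (left > right) != flip            # same condition as in the diff
    out = np.where(swap, y[:, ::-1, :], y)   # exchange the pair where needed
    return out.reshape(x.shape)

x = np.array([5, 1, 4, 8, 2, 7, 3, 6])
print(compare_and_swap(x, flip=False, i=2, n_dims=3))  # [1 5 4 8 2 7 3 6]
```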
```diff
@@ -361,14 +362,14 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
     order_type 2 == alternating
     '''
     n_outer: core.constexpr = x.numel >> n_dims
+    core.static_assert(stage <= n_dims)
     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
     # descending order.
     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
     # a stride of 2) at this stage
     if order == 2:
-        core.static_assert(stage <= (n_dims))
-        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**(stage)]
+        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
         flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
     else:
         flip = order
```
```diff
@@ -378,47 +379,30 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core
     return x
 
 
+@core._tensor_member_fn
 @jit
-def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
     """
     Sorts a tensor along a specified dimension.
 
     :param x: The input tensor to be sorted.
     :type x: Tensor
     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
     :type dim: int, optional
-    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
-    :type k: int, optional
     :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
     :type descending: bool, optional
     """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
     # iteratively run bitonic merge-sort steps
-    n_outer: core.constexpr = x.numel >> _log2(x.shape[_dim])
-    log_n: core.constexpr = _log2(x.shape[_dim])
-    log_k: core.constexpr = log_n if k is None else _log2(k)
-    for i in core.static_range(1, log_k + 1):
-        x = _bitonic_merge(x, i, 2 if i < log_n else descending, log_n)
-    # select top k elements using bitonic top-k
-    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
-    for i in core.static_range(log_k + 1, log_n + 1):
-        x = core.reshape(x, [n_outer * 2**(log_n - i), 2, 2**log_k])
-        x = max(x, axis=1) if descending else min(x, axis=1)
-        x = core.reshape(x, [n_outer, 2**(log_n - i + log_k)])
-        x = _bitonic_merge(x, log_k, 2 if i < log_n else descending, _log2(x.shape[_dim]))
+    n_dims: core.constexpr = _log2(x.shape[_dim])
+    for i in core.static_range(1, n_dims + 1):
+        x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
     return x
 
 
-@jit
-def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
-    return sort_impl(x, dim=dim, descending=descending)
-
-
-@jit
-def topk(x, k: core.constexpr, dim: core.constexpr = None):
-    return sort_impl(x, k=k, dim=dim, descending=True)
+# flip
 
 
 def _get_flip_dim(dim, shape):
```
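The restored `sort` therefore runs `_log2(N)` merge stages, alternating direction (order == 2) at every stage except the last, which honors `descending`. A tiny illustrative script (hypothetical, merely mirroring the loop above) prints the schedule for an 8-element row:

```python
n_dims = 3       # _log2(8)
descending = 0   # final-stage order: 0 ascending, 1 descending
for i in range(1, n_dims + 1):
    order = 2 if i < n_dims else descending
    print(f"_bitonic_merge(x, stage={i}, order={order}, n_dims={n_dims})")
```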
```diff
@@ -450,8 +434,14 @@ def flip(x, dim=None):
 
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
     y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
+    y = core.expand_dims(y, start)
+    flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
     for i in core.static_range(start, steps):
-        y = y ^ xor_sum(y, i, True)
+        flip2 = flip
+        for j in core.static_range(0, steps + 1):
+            if j != i and j != i + 1:
+                flip2 = core.expand_dims(flip2, j)
+        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
```
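The new `flip` body replaces the xor_sum trick with sums against the 2x2 exchange matrix built by `(core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))`. A minimal NumPy check (an illustrative sketch, not the library code) of why multiplying along an axis by [[0, 1], [1, 0]] and summing reverses that axis:

```python
import numpy as np

exchange = (np.arange(2)[:, None] == 1 - np.arange(2)).astype(int)  # [[0,1],[1,0]]
y = np.arange(4).reshape(2, 2)  # a [2] * steps block with steps = 2
# sum_b y[b, c] * exchange[a, b] picks out y[1 - a, c]: axis 0 reversed
reversed_axis0 = (y[None, :, :] * exchange[:, :, None]).sum(axis=1)
assert (reversed_axis0 == y[::-1, :]).all()
```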
