Skip to content

Commit 191ece3

Browse files
authored
Use xor-swap trick to simplify tl.sort and tl.flip (triton-lang#6486)
This improves the runtime of an internal radix sort benchmark by 25%.
1 parent 31b2b23 commit 191ece3

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

python/triton/language/standard.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -338,20 +338,19 @@ def cumprod(input, axis=0, reverse=False):
338338
def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
339339
n_outer: core.constexpr = x.numel >> n_dims
340340
shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
341-
y = core.reshape(x, shape)
342-
# slice left/right with 'stride' 2**(n_dims - i - 1)
343-
left, right = core.split(core.permute(y, (0, 2, 1)))
344-
left = core.reshape(core.broadcast_to(left[:, None, :], shape), x.shape)
345-
right = core.reshape(core.broadcast_to(right[:, None, :], shape), x.shape)
346-
left = left.to(y.dtype)
347-
right = right.to(y.dtype)
348-
# actual compare-and-swap
341+
342+
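# flip along the middle (size-2) dimension: XOR-ing each element with the XOR of its pair yields its partner, since a ^ (a ^ b) == b (the bitwise XORs will be optimised away):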
349343
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
350-
ileft = left.to(idtype, bitcast=True)
351-
iright = right.to(idtype, bitcast=True)
352-
ix = x.to(idtype, bitcast=True)
353-
ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
354-
return ret.to(x.dtype, bitcast=True)
344+
ix = core.reshape(x, shape).to(idtype, bitcast=True)
345+
iy = ix ^ xor_sum(ix, 1, True)
346+
y = core.reshape(iy.to(x.dtype, bitcast=True), x.shape)
347+
348+
# determines whether each element is in the right (index 1) rather than the left (index 0) position along the middle size-2 axis:
349+
is_right = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
350+
351+
# conditional swap:
352+
ret = core.where((x > y) != (flip ^ is_right), y, x)
353+
return ret
355354

356355

357356
@jit
@@ -451,14 +450,8 @@ def flip(x, dim=None):
451450

452451
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
453452
y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
454-
y = core.expand_dims(y, start)
455-
flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
456453
for i in core.static_range(start, steps):
457-
flip2 = flip
458-
for j in core.static_range(0, steps + 1):
459-
if j != i and j != i + 1:
460-
flip2 = core.expand_dims(flip2, j)
461-
y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
454+
y = y ^ xor_sum(y, i, True)
462455
x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
463456
return x
464457

0 commit comments

Comments
 (0)