Skip to content

Commit 579b3bc

Browse files
committed
rewrite of the torch logic
1 parent c88e93e commit 579b3bc

File tree

1 file changed

+52
-35
lines changed

1 file changed

+52
-35
lines changed

src/array_api_extra/_delegation.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,12 @@ def partition(
346346
The ordering of the elements in the two partitions on either side of
347347
the k-th element in the output array is undefined.
348348
349+
Notes:
350+
If `xp` implements `partition` or an equivalent method (e.g. kthvalue for torch),
351+
complexity will likely be O(n).
352+
If not, this function simply calls `xp.sort` and complexity is O(n log n).
353+
354+
Parameters
349355
----------
350356
a : Array
351357
Input array.
@@ -361,11 +367,6 @@ def partition(
361367
-------
362368
partitioned_array
363369
Array of the same type and shape as `a`.
364-
365-
Notes:
366-
If `xp` implements `partition` or an equivalent method (e.g. topk for torch),
367-
complexity will likely be O(n).
368-
If not, this function simply calls `xp.sort` and complexity is O(n log n).
369370
"""
370371
# Validate inputs.
371372
if xp is None:
@@ -389,26 +390,32 @@ def partition(
389390
if not (axis == -1 or axis == a.ndim - 1):
390391
a = xp.transpose(a, axis, -1)
391392

392-
# Get smallest `kth` elements along axis
393-
kth += 1 # HACK: we use a non-specified behavior of torch.topk:
394-
# in `a_left`, the element in the last position is the max
395-
a_left, indices = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
393+
out = xp.empty_like(a)
394+
ranks = xp.arange(a.shape[-1]).expand_as(a)
395+
396+
split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
397+
del indices
396398

397-
# Build a mask to remove the selected elements
398-
mask_right = xp.ones(a.shape, dtype=bool)
399-
mask_right.scatter_(dim=-1, index=indices, value=False)
399+
# fill the left-side of the partition
400+
mask_src = a < split_value
401+
n_left = mask_src.sum(dim=-1, keepdim=True)
402+
mask_dest = ranks < n_left
403+
out[mask_dest] = a[mask_src]
400404

401-
# Remaining elements along axis
402-
a_right = a[mask_right] # 1-d array
405+
# fill the middle of the partition
406+
mask_src = a == split_value
407+
n_left += mask_src.sum(dim=-1, keepdim=True)
408+
mask_dest ^= ranks < n_left
409+
out[mask_dest] = a[mask_src]
403410

404-
# Reshape. This is valid only because we work on the last axis
405-
a_right = xp.reshape(a_right, shape=(*a.shape[:-1], -1))
411+
# fill the right-side of the partition
412+
mask_src = a > split_value
413+
mask_dest = ranks >= n_left
414+
out[mask_dest] = a[mask_src]
406415

407-
# Concatenate the two parts along axis
408-
partitioned_array = xp.cat((a_left, a_right), dim=-1)
409416
if not (axis == -1 or axis == a.ndim - 1):
410-
partitioned_array = xp.transpose(partitioned_array, axis, -1)
411-
return partitioned_array
417+
out = xp.transpose(out, axis, -1)
418+
return out
412419

413420
# Note: dask topk/argtopk sort the return values, so it's
414421
# not much more efficient than sorting everything when
@@ -427,9 +434,15 @@ def argpartition(
427434
) -> Array:
428435
"""
429436
Perform an indirect partition along the given axis.
437+
430438
It returns an array of indices of the same shape as `a` that
431439
index data along the given axis in partitioned order.
432440
441+
Notes:
442+
If `xp` implements `argpartition` or an equivalent method (e.g. kthvalue for torch),
443+
complexity will likely be O(n).
444+
If not, this function simply calls `xp.argsort` and complexity is O(n log n).
445+
433446
Parameters
434447
----------
435448
a : Array
@@ -446,11 +459,6 @@ def argpartition(
446459
-------
447460
index_array
448461
Array of indices that partition `a` along the specified axis.
449-
450-
Notes:
451-
If `xp` implements `argpartition` or an equivalent method (e.g. topk for torch),
452-
complexity will likely be O(n).
453-
If not, this function simply calls `xp.argsort` and complexity is O(n log n).
454462
"""
455463
# Validate inputs.
456464
if xp is None:
@@ -478,20 +486,29 @@ def argpartition(
478486
if not (axis == -1 or axis == a.ndim - 1):
479487
a = xp.transpose(a, axis, -1)
480488

481-
kth += 1 # HACK
482-
_, indices_left = xp.topk(a, kth, dim=-1, largest=False, sorted=False)
489+
ranks = xp.arange(a.shape[-1]).expand_as(a)
490+
out = xp.empty_like(ranks)
491+
492+
split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
493+
del indices
494+
495+
mask_src = a < split_value
496+
n_left = mask_src.sum(dim=-1, keepdim=True)
497+
mask_dest = ranks < n_left
498+
out[mask_dest] = ranks[mask_src]
483499

484-
mask_right = xp.ones(a.shape, dtype=bool)
485-
mask_right.scatter_(dim=-1, index=indices_left, value=False)
500+
mask_src = a == split_value
501+
n_left += mask_src.sum(dim=-1, keepdim=True)
502+
mask_dest ^= ranks < n_left
503+
out[mask_dest] = ranks[mask_src]
486504

487-
indices_right = xp.nonzero(mask_right)[-1]
488-
indices_right = xp.reshape(indices_right, shape=(*a.shape[:-1], -1))
505+
mask_src = a > split_value
506+
mask_dest = ranks >= n_left
507+
out[mask_dest] = ranks[mask_src]
489508

490-
# Concatenate the two parts along axis
491-
index_array = xp.cat((indices_left, indices_right), dim=-1)
492509
if not (axis == -1 or axis == a.ndim - 1):
493-
index_array = xp.transpose(index_array, axis, -1)
494-
return index_array
510+
out = xp.transpose(out, axis, -1)
511+
return out
495512

496513
# Note: dask topk/argtopk sort the return values, so it's
497514
# not much more efficient than sorting everything when

0 commit comments

Comments
 (0)