Skip to content

Commit a523b6f

Browse files
author
Yibing Liu
committed
Add python api for argsort_op
1 parent 7ca511e commit a523b6f

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

doc/fluid/api/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,12 @@ argmax
11051105
.. autofunction:: paddle.fluid.layers.argmax
11061106
:noindex:
11071107

1108+
argsort
1109+
------
1110+
1111+
.. autofunction:: paddle.fluid.layers.argsort
1112+
:noindex:
1113+
11081114
ones
11091115
----
11101116

paddle/fluid/operators/argsort_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ class ArgsortOp : public framework::OperatorWithKernel {
3434

3535
auto num_dims = in_dims.size();
3636
PADDLE_ENFORCE(axis < num_dims,
37-
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
38-
"dimension %d.",
37+
"Attr(axis) %d of ArgsortOp is out of bounds for Input(X)'s "
38+
"rank %d.",
39+
axis, num_dims);
40+
PADDLE_ENFORCE(axis >= -num_dims,
41+
"Attr(axis) %d of ArgsortOp must be not less than "
42+
"-rank(Input(X)) (%d).",
3943
axis, num_dims);
40-
PADDLE_ENFORCE(in_dims.size() + axis >= 0,
41-
"Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
42-
"must be nonnegative.",
43-
axis, in_dims.size());
4444

4545
ctx->SetOutputDim("Out", in_dims);
4646
ctx->SetOutputDim("Indices", in_dims);

python/paddle/fluid/layers/tensor.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
'fill_constant',
3434
'argmin',
3535
'argmax',
36+
'argsort',
3637
'ones',
3738
'zeros',
3839
'reverse',
@@ -438,6 +439,56 @@ def argmax(x, axis=0):
438439
return out
439440

440441

442+
def argsort(input, axis=-1):
443+
"""
444+
Performs sorting on the input Variable along the given axis, and outputs
445+
sorted data Varibale and its corresponding index Variable with the same
446+
shape as :attr:`input`.
447+
448+
.. code-block:: text
449+
450+
For example, the given axis is -1 and the input Variable
451+
452+
input = [[0.15849551, 0.45865775, 0.8563702 ],
453+
[0.12070083, 0.28766365, 0.18776911]],
454+
455+
after argsort, the sorted Vairable becomes
456+
457+
out = [[0.15849551, 0.45865775, 0.8563702 ],
458+
[0.12070083, 0.18776911, 0.28766365]],
459+
460+
and the sorted indices along the given axis turn outs to be
461+
462+
indices = [[0, 1, 2],
463+
[0, 2, 1]]
464+
465+
Args:
466+
input(Variable): The input Variable for sorting.
467+
axis(int): The axis along which to sort the input Variable. When
468+
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
469+
rank(:attr:`input`). Default -1, the last dimension.
470+
471+
Returns:
472+
tuple: A tuple of sorted data Variable and the sorted indices.
473+
474+
Examples:
475+
.. code-block:: python
476+
477+
input = fluid.layers.data(data=[2, 3])
478+
out, indices = fluid.layers.argsort(input, axis=0)
479+
"""
480+
helper = LayerHelper("argsort", **locals())
481+
out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True)
482+
ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True)
483+
helper.append_op(
484+
type='argsort',
485+
inputs={'X': input},
486+
outputs={'Out': out,
487+
'Indics': ids},
488+
attts={'axis': axis})
489+
return out, ids
490+
491+
441492
def ones(shape, dtype, force_cpu=False):
442493
"""
443494
**ones**

0 commit comments

Comments
 (0)