Skip to content

Commit 693083a

Browse files
authored
add index sample (#25260)
* test=release/1.8, add index sample
1 parent 0fff183 commit 693083a

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@
212212
'flip',
213213
'roll',
214214
'log_softmax',
215+
'index_sample',
215216
]
216217

217218

@@ -16555,3 +16556,84 @@ def log_softmax(input, axis=None, dtype=None, name=None):
1655516556
type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log})
1655616557

1655716558
return outs_log
16559+
16560+
16561+
def index_sample(x, index):
16562+
"""
16563+
**IndexSample Layer**
16564+
IndexSample OP returns the element of the specified location of X,
16565+
and the location is specified by Index.
16566+
16567+
.. code-block:: text
16568+
16569+
Args:
16570+
x (Variable): The source input tensor with 2-D shape. Supported data type is
16571+
int32, int64, float32, float64.
16572+
index (Variable): The index input tensor with 2-D shape, first dimension should be same with X.
16573+
Data type is int32 or int64.
16574+
16575+
Returns:
16576+
Variable: A tensor with the same shape as `index` .
16577+
16578+
Examples:
16579+
.. code-block:: python
16580+
16581+
import paddle.fluid as fluid
16582+
import numpy as np
16583+
16584+
data = np.array([[1.0, 2.0, 3.0, 4.0],
16585+
[5.0, 6.0, 7.0, 8.0],
16586+
[9.0, 10.0, 11.0, 12.0]]).astype('float32')
16587+
16588+
data_index = np.array([[0, 1, 2],
16589+
[1, 2, 3],
16590+
[0, 0, 0]]).astype('int32')
16591+
16592+
target_data = np.array([[100, 200, 300, 400],
16593+
[500, 600, 700, 800],
16594+
[900, 1000, 1100, 1200]]).astype('int32')
16595+
16596+
with fluid.dygraph.guard():
16597+
x = fluid.dygraph.to_variable(data)
16598+
index = fluid.dygraph.to_variable(data_index)
16599+
target = fluid.dygraph.to_variable(target_data)
16600+
16601+
out_z1 = fluid.layers.index_sample(x, index)
16602+
print(out_z1.numpy())
16603+
#[[1. 2. 3.]
16604+
# [6. 7. 8.]
16605+
# [9. 9. 9.]]
16606+
16607+
# Use the index of the maximum value by topk op
16608+
# get the value of the element of the corresponding index in other tensors
16609+
top_value, top_index = fluid.layers.topk(x, k=2)
16610+
out_z2 = fluid.layers.index_sample(target, top_index)
16611+
print(top_value.numpy())
16612+
#[[ 4. 3.]
16613+
# [ 8. 7.]
16614+
# [12. 11.]]
16615+
16616+
print(top_index.numpy())
16617+
#[[3 2]
16618+
# [3 2]
16619+
# [3 2]]
16620+
16621+
print(out_z2.numpy())
16622+
#[[ 400 300]
16623+
# [ 800 700]
16624+
# [1200 1100]]
16625+
"""
16626+
helper = LayerHelper("index_sample", **locals())
16627+
16628+
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
16629+
'fluid.layers.index_sample')
16630+
check_variable_and_dtype(index, 'index', ['int32', 'int64'],
16631+
'fluid.layers.index_sample')
16632+
out = helper.create_variable_for_type_inference(dtype=x.dtype)
16633+
16634+
helper.append_op(
16635+
type='index_sample',
16636+
inputs={'X': x,
16637+
'Index': index},
16638+
outputs={'Out': out})
16639+
return out

python/paddle/fluid/tests/unittests/test_index_sample_op.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,33 @@ def config(self):
9696
self.index_type = "int64"
9797

9898

99+
class TestIndexSampleShape(unittest.TestCase):
100+
def test_shape(self):
101+
import paddle.fluid as fluid
102+
import paddle
103+
104+
# create x value
105+
x_shape = (2, 5)
106+
x_type = "float64"
107+
x_np = np.random.random(x_shape).astype(x_type)
108+
109+
# create index value
110+
index_shape = (2, 3)
111+
index_type = "int32"
112+
index_np = np.random.randint(
113+
low=0, high=x_shape[1], size=index_shape).astype(index_type)
114+
115+
x = fluid.data(name='x', shape=[-1, 5], dtype='float64')
116+
index = fluid.data(name='index', shape=[-1, 3], dtype='int32')
117+
output = fluid.layers.index_sample(x=x, index=index)
118+
119+
place = fluid.CPUPlace()
120+
exe = fluid.Executor(place=place)
121+
exe.run(fluid.default_startup_program())
122+
123+
feed = {'x': x_np, 'index': index_np}
124+
res = exe.run(feed=feed, fetch_list=[output])
125+
126+
99127
if __name__ == "__main__":
100128
unittest.main()

0 commit comments

Comments
 (0)