Skip to content

Commit 86d8659

Browse files
Add python wrapper for gather op. (#11033)
* Add python wrapper for gather op. * Add unitest for 'rank==1' and fix comments. * Fix comments.
1 parent 28dc9ba commit 86d8659

File tree

4 files changed

+69
-4
lines changed

4 files changed

+69
-4
lines changed

doc/fluid/api/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,3 +1009,9 @@ ____
10091009
.. autofunction:: paddle.fluid.layers.upsampling_bilinear2d
10101010
:noindex:
10111011

1012+
gather
1013+
____
1014+
1015+
.. autofunction:: paddle.fluid.layers.gather
1016+
:noindex:
1017+

paddle/fluid/operators/gather_op.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class GatherOp : public framework::OperatorWithKernel {
3333
auto index_dims = ctx->GetInputDim("Index");
3434
PADDLE_ENFORCE(index_dims.size() == 1);
3535
int batch_size = ctx->GetInputDim("Index")[0];
36-
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
3736
framework::DDim output_dims(ctx->GetInputDim("X"));
3837
output_dims[0] = batch_size;
3938
ctx->SetOutputDim("Out", output_dims);

python/paddle/fluid/layers/nn.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
'roi_pool',
8383
'dice_loss',
8484
'upsampling_bilinear2d',
85+
'gather',
8586
'random_crop',
8687
]
8788

@@ -3889,7 +3890,6 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
38893890

38903891
def dice_loss(input, label, epsilon=0.00001):
38913892
"""
3892-
**Dice loss Layer**
38933893
Dice loss for comparing the similarity of two batch of data,
38943894
usually is used for binary image segmentation i.e. labels are binary.
38953895
The dice loss can be defined as below equation:
@@ -3999,6 +3999,55 @@ def _is_list_or_turple_(data):
39993999
return out
40004000

40014001

4002+
def gather(input, index):
4003+
"""
4004+
Output is obtained by gathering entries of the outer-most dimension
4005+
of X indexed by `index` and concatenate them together.
4006+
4007+
.. math::
4008+
4009+
Out = X[Index]
4010+
4011+
4012+
.. code-block:: text
4013+
4014+
4015+
Given:
4016+
4017+
X = [[1, 2],
4018+
[3, 4],
4019+
[5, 6]]
4020+
4021+
Index = [1, 2]
4022+
4023+
Then:
4024+
4025+
Out = [[3, 4],
4026+
[5, 6]]
4027+
4028+
Args:
4029+
input (Variable): The source input with rank>=1.
4030+
index (Variable): The index input with rank=1.
4031+
4032+
Returns:
4033+
output (Variable): The output is a tensor with the same rank as input.
4034+
4035+
Examples:
4036+
.. code-block:: python
4037+
4038+
output = fluid.layers.gather(x, index)
4039+
"""
4040+
helper = LayerHelper('gather', **locals())
4041+
dtype = helper.input_dtype()
4042+
out = helper.create_tmp_variable(dtype)
4043+
helper.append_op(
4044+
type="gather",
4045+
inputs={"X": input,
4046+
"Index": index},
4047+
outputs={"Out": out})
4048+
return out
4049+
4050+
40024051
def random_crop(input, shape, seed=1):
40034052
helper = LayerHelper("random_crop", **locals())
40044053
dtype = helper.input_dtype()

python/paddle/fluid/tests/unittests/test_gather_op.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
class TestGatherOp(OpTest):
2121
def setUp(self):
2222
self.op_type = "gather"
23-
xnp = np.random.random((10, 20)).astype("float32")
24-
self.inputs = {'X': xnp, 'Index': np.array([1, 3, 5]).astype("int32")}
23+
self.config()
24+
xnp = np.random.random(self.x_shape).astype("float32")
25+
self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")}
2526
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
2627

2728
def test_check_output(self):
@@ -30,6 +31,16 @@ def test_check_output(self):
3031
def test_check_grad(self):
3132
self.check_grad(['X'], 'Out')
3233

34+
def config(self):
35+
self.x_shape = (10, 20)
36+
self.index = [1, 3, 5]
37+
38+
39+
class TestCase1(TestGatherOp):
40+
def config(self):
41+
self.x_shape = (10)
42+
self.index = [1, 3, 5]
43+
3344

3445
if __name__ == "__main__":
3546
unittest.main()

0 commit comments

Comments
 (0)