Skip to content

Commit 65a94be

Browse files
authored
Merge pull request #11223 from JiayiFeng/dev_reverse_op
Add reverse op
2 parents f40fc24 + ea73fb8 commit 65a94be

File tree

5 files changed

+319
-0
lines changed

5 files changed

+319
-0
lines changed

paddle/fluid/operators/reverse_op.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reverse_op.h"
16+
#include <vector>
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class ReverseOp : public framework::OperatorWithKernel {
22+
public:
23+
using framework::OperatorWithKernel::OperatorWithKernel;
24+
25+
void InferShape(framework::InferShapeContext* ctx) const override {
26+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
28+
const auto& x_dims = ctx->GetInputDim("X");
29+
const auto& axis = ctx->Attrs().Get<std::vector<int>>("axis");
30+
PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty.");
31+
for (int a : axis) {
32+
PADDLE_ENFORCE_LT(a, x_dims.size(),
33+
"The axis must be less than input tensor's rank.");
34+
}
35+
ctx->SetOutputDim("Out", x_dims);
36+
}
37+
};
38+
39+
// Declares the reverse op's inputs, outputs, attributes, and user-facing
// documentation for the op proto / Python layer.
class ReverseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The LoDTensor to be flipped.");
    AddOutput("Out", "The LoDTensor after flipping.");
    // NOTE(review): no checker is attached to 'axis'; validation happens
    // only in ReverseOp::InferShape.
    AddAttr<std::vector<int>>(
        "axis", "The axises that along which order of elements is reversed.");
    AddComment(R"DOC(
Reverse Operator.

Reverse the order of elements in the input LoDTensor along given axises.

Case 1:
Given
X = [[1, 2, 3, 4, 5]
[6, 7, 8, 9, 10]
[11, 12, 13, 14, 15]],
and
axis = [0],
we get:
Out = [[11, 12, 13, 14, 15]
[6, 7, 8, 9, 10]
[1, 2, 3, 4, 5]].

Case 2:
Given
X = [[[1, 2, 3, 4]
[5, 6, 7, 8]]
[[9, 10, 11, 12]
[13, 14, 15, 16]]],
and
axis = [0, 2],
we get:
Out = [[[12, 11, 10, 9]
[16, 15, 14, 13]]
[[4, 3, 2, 1]
[8, 7, 6, 5]]],
)DOC");
  }
};
79+
80+
class ReverseGradMaker : public framework::SingleGradOpDescMaker {
81+
public:
82+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
83+
84+
std::unique_ptr<framework::OpDesc> Apply() const override {
85+
auto* grad_op = new framework::OpDesc();
86+
grad_op->SetType("reverse");
87+
grad_op->SetInput("X", OutputGrad("Out"));
88+
grad_op->SetOutput("Out", InputGrad("X"));
89+
grad_op->SetAttr("axis", GetAttr("axis"));
90+
return std::unique_ptr<framework::OpDesc>(grad_op);
91+
}
92+
};
93+
94+
} // namespace operators
95+
} // namespace paddle
96+
97+
namespace ops = paddle::operators;
// The grad op is registered with the forward ReverseOp class itself:
// ReverseGradMaker wires Grad(Out) into "X", so the backward pass is just
// another reverse along the same axes and needs no dedicated op/kernel.
REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker,
                  ops::ReverseGradMaker);
REGISTER_OPERATOR(reverse_grad, ops::ReverseOp);
// CPU kernels for all supported element types.
REGISTER_OP_CPU_KERNEL(
    reverse, ops::ReverseKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ReverseKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::ReverseKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::ReverseKernel<paddle::platform::CPUDeviceContext, bool>,
    ops::ReverseKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ReverseKernel<paddle::platform::CPUDeviceContext, double>)

paddle/fluid/operators/reverse_op.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/reverse_op.h"
16+
17+
namespace ops = paddle::operators;
// CUDA kernels for the same element types as the CPU registration in
// reverse_op.cc; the shared ReverseKernel template is device-agnostic.
REGISTER_OP_CUDA_KERNEL(
    reverse, ops::ReverseKernel<paddle::platform::CUDADeviceContext, int>,
    ops::ReverseKernel<paddle::platform::CUDADeviceContext, uint8_t>,
    ops::ReverseKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::ReverseKernel<paddle::platform::CUDADeviceContext, bool>,
    ops::ReverseKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ReverseKernel<paddle::platform::CUDADeviceContext, double>)

paddle/fluid/operators/reverse_op.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <vector>
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
template <typename DeviceContext, typename T, int Rank>
23+
struct ReverseFunctor {
24+
void operator()(const DeviceContext& context, const framework::LoDTensor& in,
25+
framework::LoDTensor* out, const std::vector<int>& axis) {
26+
Eigen::array<bool, Rank> reverse_axis;
27+
for (int i = 0; i < Rank; ++i) {
28+
reverse_axis[i] = false;
29+
}
30+
for (int a : axis) {
31+
reverse_axis[a] = true;
32+
}
33+
34+
auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
35+
auto out_eigen = framework::EigenTensor<T, Rank>::From(*out);
36+
auto* dev = context.eigen_device();
37+
38+
out_eigen.device(*dev) = in_eigen.reverse(reverse_axis);
39+
}
40+
};
41+
42+
template <typename DeviceContext, typename T>
43+
class ReverseKernel : public framework::OpKernel<T> {
44+
public:
45+
void Compute(const framework::ExecutionContext& context) const override {
46+
auto* x = context.Input<framework::LoDTensor>("X");
47+
auto* out = context.Output<framework::LoDTensor>("Out");
48+
out->mutable_data<T>(context.GetPlace());
49+
const auto& axis = context.Attr<std::vector<int>>("axis");
50+
int rank = x->dims().size();
51+
auto& dev_ctx = context.template device_context<DeviceContext>();
52+
53+
switch (rank) {
54+
case 1:
55+
ReverseFunctor<DeviceContext, T, 1> functor1;
56+
functor1(dev_ctx, *x, out, axis);
57+
break;
58+
case 2:
59+
ReverseFunctor<DeviceContext, T, 2> functor2;
60+
functor2(dev_ctx, *x, out, axis);
61+
break;
62+
case 3:
63+
ReverseFunctor<DeviceContext, T, 3> functor3;
64+
functor3(dev_ctx, *x, out, axis);
65+
break;
66+
case 4:
67+
ReverseFunctor<DeviceContext, T, 4> functor4;
68+
functor4(dev_ctx, *x, out, axis);
69+
break;
70+
case 5:
71+
ReverseFunctor<DeviceContext, T, 5> functor5;
72+
functor5(dev_ctx, *x, out, axis);
73+
break;
74+
case 6:
75+
ReverseFunctor<DeviceContext, T, 6> functor6;
76+
functor6(dev_ctx, *x, out, axis);
77+
break;
78+
default:
79+
PADDLE_THROW(
80+
"Reserve operator doesn't supports tensors whose ranks are greater "
81+
"than 6.");
82+
}
83+
}
84+
};
85+
86+
} // namespace operators
87+
} // namespace paddle

python/paddle/fluid/layers/tensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,40 @@ def zeros(shape, dtype, force_cpu=False):
363363
return fill_constant(value=0.0, **locals())
364364

365365

366+
def reverse(x, axis):
    """
    **reverse**

    This function reverses the input 'x' along the given axes.

    Args:
        x(Variable): the input to be reversed.
        axis(int|tuple|list): Axis along which the order of elements
                    is reversed. If it is a tuple or a list, reversing
                    will be applied on each axis in the tuple or list.

    Returns:
        Variable: The reversed tensor.

    Examples:
        .. code-block:: python

          out = fluid.layers.reverse(x=in, axis=0)
          # or:
          out = fluid.layers.reverse(x=in, axis=[0,1])
    """
    # Normalize a scalar axis to a list, which is what the op attr expects.
    if isinstance(axis, int):
        axis = [axis]
    helper = LayerHelper("reverse", **locals())
    out = helper.create_tmp_variable(dtype=x.dtype)
    # The C++ ReverseOp declares its input slot as "X" (see reverse_op.cc's
    # AddInput("X", ...)); the original code wrongly used 'Input' here, so
    # the operator would never receive its input tensor.
    helper.append_op(
        type='reverse',
        inputs={'X': x},
        outputs={'Out': [out]},
        attrs={'axis': axis})
    return out
398+
399+
366400
def save(x, file_path, overwrite=True):
367401
"""
368402
Saves a variable as a file.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
from op_test import OpTest
18+
19+
20+
class TestReverseOp(OpTest):
    # Base case: flip a 2-D float32 tensor along axis 0. Subclasses
    # override initTestCase to vary the input shape and axes.
    def initTestCase(self):
        self.x = np.random.random((3, 4)).astype('float32')
        self.axis = [0]

    def setUp(self):
        self.initTestCase()
        self.op_type = "reverse"
        self.inputs = {"X": self.x}
        self.attrs = {'axis': self.axis}
        # Reference output: apply numpy's flip once per requested axis.
        expected = self.x
        for ax in self.axis:
            expected = np.flip(expected, axis=ax)
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
41+
42+
# Flip a 2-D tensor along its second axis.
class TestCase0(TestReverseOp):
    def initTestCase(self):
        self.x = np.random.random((3, 4)).astype('float32')
        self.axis = [1]


# Flip a 2-D tensor along both axes.
class TestCase1(TestReverseOp):
    def initTestCase(self):
        self.x = np.random.random((3, 4)).astype('float32')
        self.axis = [0, 1]


# Flip a 3-D tensor along its first and last axes.
class TestCase2(TestReverseOp):
    def initTestCase(self):
        self.x = np.random.random((3, 4, 5)).astype('float32')
        self.axis = [0, 2]


# Flip a 3-D tensor along its last two axes.
class TestCase3(TestReverseOp):
    def initTestCase(self):
        self.x = np.random.random((3, 4, 5)).astype('float32')
        self.axis = [1, 2]


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)