Skip to content

Commit 97d47a7

Browse files
authored
Merge pull request #13913 from sneaxiy/seq_reverse
Add sequence_reverse_op
2 parents a3efba1 + 016bf51 commit 97d47a7

File tree

7 files changed

+355
-0
lines changed

7 files changed

+355
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
174174
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
175175
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
176176
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
177+
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
177178
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
178179
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
179180
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))

paddle/fluid/operators/math/algorithm.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
3939
return -1;
4040
}
4141

42+
// Returns the index of the first element of x[0, num) that is not less than
// val, or num when every element is less than val.
//
// x is expected to be sorted (partitioned with respect to val), as required
// by std::lower_bound, which the host path forwards to. When compiling device
// code (__CUDA_ARCH__ defined) an equivalent hand-written binary search is
// used instead.
template <typename T>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
  // Same halving scheme as the reference implementation at
  // https://en.cppreference.com/w/cpp/algorithm/lower_bound
  // expressed with an offset into x instead of a moving pointer.
  int64_t offset = 0;
  int64_t remaining = static_cast<int64_t>(num);
  while (remaining > 0) {
    const int64_t half = remaining >> 1;
    if (x[offset + half] < val) {
      // The probed element and everything before it are too small.
      offset += half + 1;
      remaining -= half + 1;
    } else {
      // The answer is at the probe or to its left.
      remaining = half;
    }
  }
  return static_cast<size_t>(offset);
#else
  return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
#endif
}
64+
65+
// Returns the index of the first element of x[0, num) that is greater than
// val, or num when no element is greater than val.
//
// x is expected to be sorted (partitioned with respect to val), as required
// by std::upper_bound, which the host path forwards to. When compiling device
// code (__CUDA_ARCH__ defined) an equivalent hand-written binary search is
// used instead.
template <typename T>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
  // Same halving scheme as the reference implementation at
  // https://en.cppreference.com/w/cpp/algorithm/upper_bound
  // expressed with an offset into x instead of a moving pointer.
  int64_t offset = 0;
  int64_t remaining = static_cast<int64_t>(num);
  while (remaining > 0) {
    const int64_t half = remaining >> 1;
    if (val < x[offset + half]) {
      // The probe is already greater than val: keep searching to its left.
      remaining = half;
    } else {
      // The probe is <= val: the answer lies strictly to its right.
      offset += half + 1;
      remaining -= half + 1;
    }
  }
  return static_cast<size_t>(offset);
#else
  return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
#endif
}
87+
4288
} // namespace math
4389
} // namespace operators
4490
} // namespace paddle
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/sequence_reverse_op.h"
16+
17+
namespace ops = paddle::operators;
18+
19+
// Registers the "sequence_reverse" operator together with its proto/checker
// maker and the maker that builds its gradient op description.
REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp,
20+
ops::SequenceReverseOpMaker,
21+
ops::SequenceReverseGradOpDescMaker);
22+
23+
// Registers the CPU kernels of sequence_reverse, one instantiation per
// supported element type: uint8_t, int, int64_t, float and double.
REGISTER_OP_CPU_KERNEL(
24+
sequence_reverse,
25+
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
26+
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int>,
27+
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
28+
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, float>,
29+
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/sequence_reverse_op.h"
16+
17+
namespace ops = paddle::operators;
18+
19+
// Registers the CUDA kernels of sequence_reverse, one instantiation per
// supported element type: uint8_t, int, int64_t, float and double. The same
// SequenceReverseOpKernel template is instantiated here with
// CUDADeviceContext.
REGISTER_OP_CUDA_KERNEL(
20+
sequence_reverse,
21+
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
22+
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int>,
23+
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
24+
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, float>,
25+
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, double>);
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/math/algorithm.h"
19+
#include "paddle/fluid/platform/for_range.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
class SequenceReverseOp : public framework::OperatorWithKernel {
25+
public:
26+
using framework::OperatorWithKernel::OperatorWithKernel;
27+
28+
void InferShape(framework::InferShapeContext *ctx) const override {
29+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
30+
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
31+
32+
auto x_dim = ctx->GetInputDim("X");
33+
PADDLE_ENFORCE_GE(x_dim.size(), 2,
34+
"Rank of Input(X) must be not less than 2.");
35+
36+
ctx->SetOutputDim("Y", x_dim);
37+
ctx->ShareLoD("X", "Y");
38+
}
39+
};
40+
41+
// Declares the interface of the sequence_reverse operator: one input "X",
// one output "Y", no attributes, plus the user-facing operator comment.
class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker {
42+
public:
43+
// Registers the input, the output and the documentation string of the op.
void Make() override {
44+
AddInput("X", "The input LoDTensor of sequence_reverse op.");
45+
AddOutput("Y", "The output LoDTensor of sequence_reverse op.");
46+
// The raw string below is runtime text (the operator's documentation); it
// must not be edited for cosmetic reasons.
AddComment(R"DOC(
47+
SequenceReverse Operator.
48+
49+
Reverse each sequence in input X along dim 0.
50+
51+
Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where:
52+
53+
X.data() = [
54+
[1, 2, 3, 4],
55+
[5, 6, 7, 8], # the 0-th sequence with length 2
56+
[9, 10, 11, 12],
57+
[13, 14, 15, 16],
58+
[17, 18, 19, 20] # the 1-st sequence with length 3
59+
]
60+
61+
The output Y would be a LoDTensor sharing the same dims and lod with input X,
62+
and:
63+
64+
Y.data() = [
65+
[5, 6, 7, 8],
66+
[1, 2, 3, 4], # the reversed 0-th sequence with length 2
67+
[17, 18, 19, 20],
68+
[13, 14, 15, 16],
69+
[9, 10, 11, 12] # the reversed 1-st sequence with length 3
70+
]
71+
72+
This Operator is useful to build a reverse dynamic RNN network.
73+
74+
This Operator only supports one-level lod currently.
75+
)DOC");
76+
}
77+
};
78+
79+
template <typename T>
80+
struct SequenceReverseFunctor {
81+
SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count,
82+
size_t row_numel)
83+
: x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {}
84+
85+
HOSTDEVICE void operator()(size_t idx_x) const {
86+
auto row_idx_x = idx_x / row_numel_;
87+
auto lod_idx = math::UpperBound(lod_, lod_count_, row_idx_x);
88+
auto row_idx_y = lod_[lod_idx - 1] + (lod_[lod_idx] - 1 - row_idx_x);
89+
auto idx_y = row_idx_y * row_numel_ + idx_x % row_numel_;
90+
y_[idx_y] = x_[idx_x];
91+
}
92+
93+
const T *x_;
94+
T *y_;
95+
const size_t *lod_;
96+
size_t lod_count_;
97+
size_t row_numel_;
98+
};
99+
100+
template <typename DeviceContext, typename T>
101+
class SequenceReverseOpKernel : public framework::OpKernel<T> {
102+
using LoDTensor = framework::LoDTensor;
103+
104+
public:
105+
void Compute(const framework::ExecutionContext &ctx) const override {
106+
auto &x = *ctx.Input<LoDTensor>("X");
107+
auto *y = ctx.Output<LoDTensor>("Y");
108+
109+
PADDLE_ENFORCE_EQ(x.lod().size(), 1,
110+
"SequenceReverse Op only support one level lod.");
111+
112+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
113+
const size_t *lod;
114+
size_t lod_count = x.lod()[0].size();
115+
116+
#ifdef PADDLE_WITH_CUDA
117+
if (platform::is_gpu_place(ctx.GetPlace())) {
118+
lod = x.lod()[0].CUDAData(ctx.GetPlace());
119+
} else {
120+
#endif
121+
lod = x.lod()[0].data();
122+
#ifdef PADDLE_WITH_CUDA
123+
}
124+
#endif
125+
126+
size_t limit = static_cast<size_t>(x.numel());
127+
size_t row_numel = static_cast<size_t>(limit / x.dims()[0]);
128+
auto *x_data = x.data<T>();
129+
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
130+
131+
PADDLE_ENFORCE_NE(x_data, y_data,
132+
"SequenceReverse Op does not support in-place operation");
133+
134+
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
135+
row_numel);
136+
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
137+
for_range(functor);
138+
}
139+
};
140+
141+
class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker {
142+
public:
143+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
144+
145+
protected:
146+
std::unique_ptr<framework::OpDesc> Apply() const override {
147+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
148+
op->SetType("sequence_reverse");
149+
op->SetInput("X", OutputGrad("Y"));
150+
op->SetOutput("Y", InputGrad("X"));
151+
op->SetAttrMap(Attrs());
152+
return op;
153+
}
154+
};
155+
156+
} // namespace operators
157+
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
'mul',
155155
'sigmoid_cross_entropy_with_logits',
156156
'maxout',
157+
'sequence_reverse',
157158
'affine_channel',
158159
]
159160

@@ -7484,6 +7485,33 @@ def maxout(x, groups, name=None):
74847485
return out
74857486

74867487

7488+
@templatedoc()
def sequence_reverse(x, name=None):
    """
    ${comment}

    Args:
        x(${x_type}): ${x_comment}
        name(basestring|None): Name of the output.

    Returns:
        out(${y_type}): ${y_comment}
    """
    # Keep this as the first statement: locals() is forwarded to LayerHelper
    # and at this point holds exactly the arguments (x, name).
    helper = LayerHelper("sequence_reverse", **locals())

    # A caller-supplied name yields an explicitly named, non-persistable
    # variable; otherwise the helper creates the output variable itself.
    if name is not None:
        out = helper.create_variable(
            name=name, dtype=x.dtype, persistable=False)
    else:
        out = helper.create_variable_for_type_inference(dtype=x.dtype)

    helper.append_op(
        type="sequence_reverse",
        inputs={"X": x},
        outputs={"Y": out},
        attrs={})
    return out
7513+
7514+
74877515
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
74887516
"""
74897517
Applies a separate affine transformation to each channel of the input.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle.fluid as fluid
17+
import paddle.fluid.core as core
18+
from op_test import OpTest
19+
import numpy as np
20+
21+
22+
class TestSequenceReverseBase(OpTest):
    # Base case for the sequence_reverse op tests. Subclasses override
    # initParameters() to change self.size / self.lod / self.dtype.

    def initParameters(self):
        # Hook for subclasses; the base case keeps the defaults from setUp().
        pass

    def setUp(self):
        # Defaults: 10 rows of shape (3, 4), split into sequences of
        # lengths 2, 3 and 5.
        self.size = (10, 3, 4)
        self.lod = [2, 3, 5]
        self.dtype = 'float32'
        self.initParameters()
        self.op_type = 'sequence_reverse'
        self.x = np.random.random(self.size).astype(self.dtype)
        self.y = self.get_output()

        self.inputs = {'X': (self.x, [self.lod]), }
        self.outputs = {'Y': (self.y, [self.lod]), }

    def get_output(self):
        # Reference result: flatten every row, then reverse the row order
        # inside each sequence, where self.lod holds the sequence lengths.
        flat_x = self.x.reshape([self.x.shape[0], -1])
        flat_y = np.ndarray(flat_x.shape).astype(self.dtype)
        seq_begin = 0
        for seq_len in self.lod:
            rows = np.arange(seq_begin, seq_begin + seq_len)
            flat_y[rows, :] = flat_x[rows[::-1], :]
            seq_begin += seq_len

        return flat_y.reshape(self.x.shape).astype(self.dtype)

    def test_output(self):
        self.check_output(0)

    def test_grad(self):
        self.check_grad(['X'], 'Y')
54+
55+
56+
# NOTE(review): the class name spells "Reserve"; "Reverse" was presumably
# intended. The name is kept unchanged here.
class TestSequenceReserve1(TestSequenceReverseBase):
    def initParameters(self):
        # 2-D input: 12 rows of 10 elements, sequence lengths 4, 5 and 3.
        self.lod = [4, 5, 3]
        self.size = (12, 10)
60+
61+
62+
class TestSequenceReverse2(TestSequenceReverseBase):
    def initParameters(self):
        # 2-D input: 12 rows of 10 elements forming one single sequence.
        self.lod = [12]
        self.size = (12, 10)
66+
67+
68+
if __name__ == '__main__':
69+
unittest.main()

0 commit comments

Comments
 (0)