
Commit bea4144

Refine the implementation and add unit test.
1 parent f20617b commit bea4144

4 files changed: +196 additions, -34 deletions

paddle/operators/sequence_reshape_op.cc

Lines changed: 55 additions & 9 deletions
@@ -27,21 +27,50 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
                     "Output(Out) of SequenceReshapeOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
-    int dimension = ctx->Attrs().Get<int>("dimension");
-    ctx->SetOutputDim("Out", {{x_dims[0], static_cast<int64_t>(dimension)}});
-    ctx->ShareLoD("X", /*->*/ "Out");
+    int dimension = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
   }
 };
 
 class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "");
-    AddOutput("Out", "");
-    AddAttr<int>("dimension", "");
-    AddAttr<bool>("is_padding", "Default padding zero.");
-    AddComment(R"DOC()DOC");
+    AddInput("X",
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+             "shape [N, M].");
+    AddOutput("Out",
+              "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+              "shape [T, new_dim] where T is calculated based on X.lod, M and "
+              "new_dim.");
+    AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");
+    AddComment(R"DOC(
+Sequence Reshape Operator.
+
+This operator rearranges the input sequences. The new dimension is set by the
+attribute new_dim, and the length of each sequence may become longer or
+shorter depending on its original length, the original dimension, and new_dim.
+The following example illustrates how this operator works:
+
+x is a LoDTensor:
+    x.lod  = [[0, 2, 6]]
+    x.data = [[0.1, 0.2], [0.3, 0.4],
+              [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
+    x.dims = [6, 2]
+
+set new_dim = 4
+
+then out is a LoDTensor:
+    out.lod  = [[0, 1, 3]]
+    out.data = [[0.1, 0.2, 0.3, 0.4],
+                [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
+    out.dims = [3, 4]
+
+Currently, only 1-level LoDTensor is supported. Please make sure that
+(original length * original dimension) is divisible by new_dim with no
+remainder for each sequence.
+
+)DOC");
   }
 };
 

@@ -63,12 +92,29 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("sequence_reshape_grad");
+    op_desc_ptr->SetInput("X", Input("X"));
+    op_desc_ptr->SetInput("Out", Output("Out"));
+    op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op_desc_ptr->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
-                  ops::SequenceReshapeOpMaker);
+                  ops::SequenceReshapeOpMaker, ops::SequenceReshapeGradOpMaker);
 REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape,
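
For intuition, the forward semantics described in the DOC comment above can be
reproduced in a few lines of NumPy. This is a minimal sketch, not PaddlePaddle
code; the helper name seq_reshape and its standalone signature are assumptions
made purely for illustration:

import numpy as np

def seq_reshape(x, lod0, new_dim):
    # Recompute the level-0 LoD; every sequence's element count must be
    # divisible by new_dim, mirroring the PADDLE_ENFORCE_EQ in the kernel.
    new_lod0 = [0]
    for i in range(len(lod0) - 1):
        total = (lod0[i + 1] - lod0[i]) * x.shape[1]  # elements in sequence i
        assert total % new_dim == 0, 'sequence %d is not divisible' % (i + 1)
        new_lod0.append(new_lod0[-1] + total // new_dim)
    # Sequences are stored contiguously, so one flat reshape regroups them all.
    out = x.reshape(-1).reshape(-1, new_dim)
    return out, [new_lod0]

x = np.array([[0.1, 0.2], [0.3, 0.4],
              [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]])
out, out_lod = seq_reshape(x, [0, 2, 6], new_dim=4)
# out.shape == (3, 4) and out_lod == [[0, 1, 3]], matching the DOC example.

Only the LoD needs per-sequence bookkeeping; the data buffer itself is left
untouched, which is why the kernel below can implement the op as plain memory
copies.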
paddle/operators/sequence_reshape_op.cu

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>);

paddle/operators/sequence_reshape_op.h

Lines changed: 42 additions & 25 deletions
@@ -26,53 +26,63 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<LoDTensor>("Out");
-    int out_width = context.Attr<int>("dimension");
-    bool whether_padding = context.Attr<bool>("whether_padding");
+    int out_width = context.Attr<int>("new_dim");
 
     const T* p_in_data = in->data<T>();
-    T* p_out_data = out->mutable_data<T>(context.GetPlace());
 
-    // compute shape for output
     auto in_dims = in->dims();
     int64_t in_width = in_dims[1];
    auto& in_lod = in->lod();
 
     PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
                       "Only support one level sequence now.");
-    PADDLE_ENFORCE_GE(
-        in_dims[0],
-        /* batch size = */ static_cast<int64_t>(in_lod[0].size() - 1),
-        "The 1st dimension of Input(X) must be equal or larger than batch "
-        "size.");
+    PADDLE_ENFORCE_EQ(
+        in_dims[0], in_lod[0].back(),
+        "Inconsistent size between X.shape[0] and X.lod()[0].back().");
 
     auto in_lod_l0 = in_lod[0];
     int seq_num = in_lod_l0.size() - 1;
 
     auto& out_lod = *out->mutable_lod();
-    out_lod.push_back(std::vector<size_t>({0}));
-    size_t offset = 0;
+    out_lod.resize(1);
+    out_lod[0].clear();
+    out_lod[0].push_back(0);
     for (int i = 0; i < seq_num; ++i) {
       size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
-      if (whether_padding) {
-        offset += std::ceil((float)(seq_len * in_width) / out_width);
-      } else {
-        offset += (seq_len * in_width) / out_width;
-      }
-      out_lod[0].push_back(offset);
+      size_t offset = 0;
+      offset = (seq_len * in_width) / out_width;
+      PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                        "Please make sure (sequence_length * dimension) can be "
+                        "divided by new_dim with no remainder for each "
+                        "sequence. The %dth sequence is invalid.",
+                        i + 1);
+      PADDLE_ENFORCE_GT(offset, 0,
+                        "Illegal operation, length of the %dth sequence "
+                        "becomes 0 after being reshaped.",
+                        i + 1);
+      out_lod[0].push_back(out_lod[0].back() + offset);
     }
 
-    out->Resize({{static_cast<int64_t>(out_lod[0].back()), out_width}});
+    out->mutable_data<T>(context.GetPlace());
+    out->Resize({static_cast<int64_t>(out_lod[0].back()), out_width});
+    T* p_out_data = out->mutable_data<T>(context.GetPlace());
     math::set_constant(context.device_context(), out, 0.0f);
 
     for (int i = 0; i < seq_num; ++i) {
       size_t in_offset = in_lod_l0[i] * in_width;
       size_t out_offset = out_lod[0][i] * out_width;
-      size_t bytes = sizeof(T) * (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
+      size_t in_count = (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
+      size_t out_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
+      size_t bytes = sizeof(T) * std::min(in_count, out_count);
       if (platform::is_cpu_place(context.GetPlace())) {
-        std::memcpy(p_out_data + out_offset, p_in_data + in_offset, bytes);
+        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_out_data + out_offset,
+                     boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_in_data + in_offset, bytes);
       } else {
 #ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx = context.template device_context<DeviceContext>();
+        auto& dev_ctx =
+            context.template device_context<platform::CUDADeviceContext>();
         memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                      p_out_data + out_offset,
                      boost::get<platform::CUDAPlace>(context.GetPlace()),
@@ -103,16 +113,23 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
     auto& out_lod = out_tensor_ptr->lod();
     int out_width = out_tensor_ptr->dims()[1];
 
+    math::set_constant(context.device_context(), x_grad_tensor_ptr, 0.0f);
+
     for (int i = 0; i < seq_num; ++i) {
       size_t src_offset = out_lod[0][i] * out_width;
       size_t dst_offset = x_lod[0][i] * x_width;
-      size_t bytes = sizeof(T) * (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
+      size_t src_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
+      size_t dst_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
+      size_t bytes = sizeof(T) * std::min(src_count, dst_count);
       if (platform::is_cpu_place(context.GetPlace())) {
-        std::memcpy(p_x_grad_data + dst_offset, p_out_grad_data + src_offset,
-                    bytes);
+        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_x_grad_data + dst_offset,
+                     boost::get<platform::CPUPlace>(context.GetPlace()),
+                     p_out_grad_data + src_offset, bytes);
       } else {
 #ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx = context.template device_context<DeviceContext>();
+        auto& dev_ctx =
+            context.template device_context<platform::CUDADeviceContext>();
         memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                      p_x_grad_data + dst_offset,
                      boost::get<platform::CUDAPlace>(context.GetPlace()),
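
The gradient kernel above is the mirror image of the forward copy: each
sequence's slice of Out's gradient is the same flat buffer as its slice of X's
gradient, regrouped to the original row width. A rough NumPy sketch under the
same divisibility assumption (seq_reshape_grad is a hypothetical name for
illustration, not a Paddle API):

import numpy as np

def seq_reshape_grad(d_out, out_lod0, x_lod0, x_width):
    # dX has X's shape; per sequence, copy the flat gradient buffer back,
    # regrouped to the original width (mirrors the memory::Copy loop above).
    d_x = np.zeros((x_lod0[-1], x_width), dtype=d_out.dtype)
    for i in range(len(x_lod0) - 1):
        flat = d_out[out_lod0[i]:out_lod0[i + 1]].ravel()
        d_x[x_lod0[i]:x_lod0[i + 1]] = flat.reshape(-1, x_width)
    return d_x

# Check against the forward example: new_dim = 4 gradients flow back to [6, 2].
d_out = np.ones((3, 4))
d_x = seq_reshape_grad(d_out, [0, 1, 3], [0, 2, 6], x_width=2)
# d_x.shape == (6, 2), every entry 1.0, since each input element is used once.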
Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+
+
+class TestSequenceReshape(OpTest):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 5, 8, 11]]
+        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+    def compute_output(self, x, x_lod, dimension):
+        x_width = x.shape[1]
+        out_lod = [[0]]
+        for i in xrange(len(x_lod[0]) - 1):
+            seq_len = x_lod[0][i + 1] - x_lod[0][i]
+            offset = (seq_len * x_width) / dimension
+            assert int(offset) * dimension == seq_len * x_width
+            out_lod[0].append(out_lod[0][-1] + int(offset))
+        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
+        for i in xrange(len(x_lod[0]) - 1):
+            x_offset = x_lod[0][i] * x_width
+            out_offset = out_lod[0][i] * dimension
+            out_count = (out_lod[0][i + 1] - out_lod[0][i]) * dimension
+            x_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width
+            count = min(out_count, x_count)
+            out.ravel()[out_offset:out_offset + count] = x.ravel()[
+                x_offset:x_offset + count]
+        return out, out_lod
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestSequenceReshape_reduce(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 24
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()
