Skip to content

Commit 4f93331

Browse files
authored
Merge pull request #7662 from pkuyym/fix-6678
Add sequence reshape operator
2 parents 9fea1d4 + b07ca1d commit 4f93331

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/sequence_reshape_op.h"
16+
#include "paddle/framework/ddim.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class SequenceReshapeOp : public framework::OperatorWithKernel {
22+
public:
23+
using framework::OperatorWithKernel::OperatorWithKernel;
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of SequenceReshapeOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
28+
"Output(Out) of SequenceReshapeOp should not be null.");
29+
auto x_dims = ctx->GetInputDim("X");
30+
auto x_numel = product(x_dims);
31+
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
32+
int new_dim = ctx->Attrs().Get<int>("new_dim");
33+
ctx->SetOutputDim("Out",
34+
{x_numel / new_dim, static_cast<int64_t>(new_dim)});
35+
}
36+
};
37+
38+
// Proto maker for the sequence_reshape operator: declares input "X",
// output "Out", the int attribute "new_dim" and the operator's user-facing
// documentation string.
class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
39+
public:
40+
SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
41+
: OpProtoAndCheckerMaker(proto, op_checker) {
42+
AddInput("X",
43+
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with shape "
44+
"being [N, M].");
45+
AddOutput("Out",
46+
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
47+
"shape [T, new_dim] where T is calculated based on X.lod, M and "
48+
"new_dim.");
49+
// new_dim is the width (second dimension) of every row of Out.
AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");
50+
AddComment(R"DOC(
51+
Sequence Reshape Operator.
52+
53+
This operator will rearrange the input sequences. The new dimension is set by
54+
attribute and length of each sequence may change longer or shorter which is
55+
decided by original length, original dimension and new dimension. The following
56+
example will help to illustrate the function of this operator:
57+
58+
x is a LoDTensor:
59+
x.lod = [[0, 2, 6]]
60+
x.data = [[1, 2], [3, 4],
61+
[5, 6], [7, 8], [9, 10], [11, 12]]
62+
x.dims = [6, 2]
63+
64+
set new_dim = 4
65+
66+
then out is a LoDTensor:
67+
out.lod = [[0, 1, 3]]
68+
out.data = [[1, 2, 3, 4],
69+
[5, 6, 7, 8], [9, 10, 11, 12]]
70+
out.dims = [3, 4]
71+
72+
Currently, only 1-level LoDTensor is supported and please make sure (original
73+
length * original dimension) can be divided by new_dim with no remainder for
74+
each sequence.
75+
76+
)DOC");
77+
}
78+
};
79+
80+
class SequenceReshapeGradOp : public framework::OperatorWithKernel {
81+
public:
82+
using framework::OperatorWithKernel::OperatorWithKernel;
83+
84+
void InferShape(framework::InferShapeContext* ctx) const override {
85+
PADDLE_ENFORCE(
86+
ctx->HasInput(framework::GradVarName("Out")),
87+
"Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
88+
PADDLE_ENFORCE(ctx->HasInput("X"),
89+
"Input(X) of SequenceReshapeGradOp should not be null.");
90+
91+
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
92+
ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
93+
}
94+
};
95+
96+
class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
97+
public:
98+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
99+
100+
protected:
101+
std::unique_ptr<framework::OpDesc> Apply() const override {
102+
auto* op_desc_ptr = new framework::OpDesc();
103+
op_desc_ptr->SetType("sequence_reshape_grad");
104+
op_desc_ptr->SetInput("X", Input("X"));
105+
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
106+
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
107+
op_desc_ptr->SetAttrMap(Attrs());
108+
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
109+
}
110+
};
111+
112+
} // namespace operators
113+
} // namespace paddle
114+
115+
// Operator and CPU kernel registration for sequence_reshape and its gradient.
namespace ops = paddle::operators;
116+
// Forward op: operator class, proto maker and grad-op-desc maker.
REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
117+
ops::SequenceReshapeOpMaker, ops::SequenceReshapeGradOpMaker);
118+
// Backward op: operator class only.
REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
119+
// CPU forward kernels for float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL(
120+
sequence_reshape,
121+
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
122+
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
123+
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
124+
ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
125+
// CPU backward kernels for the same four element types.
REGISTER_OP_CPU_KERNEL(
126+
sequence_reshape_grad,
127+
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
128+
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
129+
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
130+
ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/sequence_reshape_op.h"
16+
17+
// CUDA kernel registration for sequence_reshape and its gradient.
namespace ops = paddle::operators;
18+
// CUDA forward kernels for float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL(
19+
sequence_reshape,
20+
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
21+
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
22+
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
23+
ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
24+
// CUDA backward kernels for the same four element types.
REGISTER_OP_CUDA_KERNEL(
25+
sequence_reshape_grad,
26+
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
27+
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, double>,
28+
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
29+
int64_t>,
30+
ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>);
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/operators/math/math_function.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using LoDTensor = framework::LoDTensor;
23+
template <typename DeviceContext, typename T>
24+
// Forward kernel of sequence_reshape.
//
// Reads LoDTensor "X" and the int attribute "new_dim", writes LoDTensor
// "Out". The data of X is copied unchanged; only the level-0 LoD and the dims
// are recomputed so that each row of Out has `new_dim` elements.
class SequenceReshapeKernel : public framework::OpKernel<T> {
 public:
  // Enforces that X has exactly one LoD level, that X.dims[0] matches the
  // last level-0 LoD offset and, when the width changes, that new_dim is
  // positive and divides (sequence_length * width) for every sequence.
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    int out_width = context.Attr<int>("new_dim");

    auto in_dims = in->dims();
    int64_t in_width = in_dims[1];
    auto& in_lod = in->lod();

    PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
                      "Only support one level sequence now.");
    PADDLE_ENFORCE_EQ(
        in_dims[0], in_lod[0].back(),
        "Inconsistent size between X.shape[0] and X.lod()[0].back().");

    // Bind by reference: the offsets are only read, so no copy is needed.
    const auto& in_lod_l0 = in_lod[0];
    int seq_num = in_lod_l0.size() - 1;

    if (in_width == out_width) {
      // Same row width: sequence boundaries are unchanged.
      out->set_lod(in->lod());
    } else {
      // out_width is the divisor below; reject zero (division by zero) and
      // negative values (which would wrap in the unsigned arithmetic).
      PADDLE_ENFORCE_GT(out_width, 0,
                        "Attribute new_dim should be greater than 0.");
      auto& out_lod = *out->mutable_lod();
      out_lod.resize(1);
      out_lod[0].resize(seq_num + 1);
      out_lod[0][0] = 0;
      for (int i = 0; i < seq_num; ++i) {
        size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
        // Number of output rows produced by the i-th sequence.
        size_t offset = (seq_len * in_width) / out_width;
        PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
                          "Please make sure (sequence_length * dimension) can "
                          "be divided by new_dim with no remainder for each "
                          "sequence. The %dth sequence is invalid.",
                          i + 1);
        out_lod[0][i + 1] = out_lod[0][i] + offset;
      }
    }

    // Copy the data first, then reinterpret it with the new dims.
    framework::Copy(*in, context.GetPlace(), out);
    out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
  }
};
68+
69+
template <typename DeviceContext, typename T>
70+
// Backward kernel of sequence_reshape: copies Out@GRAD into X@GRAD and gives
// the result the dims of X.
class SequenceReshapeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* x = context.Input<LoDTensor>("X");
    const auto* out_grad =
        context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* x_grad = context.Output<LoDTensor>(framework::GradVarName("X"));

    // Allocate X@GRAD on the execution place, fill it from Out@GRAD, then
    // restore the dims of X.
    x_grad->mutable_data<T>(context.GetPlace());
    framework::Copy(*out_grad, context.GetPlace(), x_grad);
    x_grad->Resize(x->dims());
  }
};
84+
85+
} // namespace operators
86+
} // namespace paddle
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
#Licensed under the Apache License, Version 2.0 (the "License");
4+
#you may not use this file except in compliance with the License.
5+
#You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
#Unless required by applicable law or agreed to in writing, software
10+
#distributed under the License is distributed on an "AS IS" BASIS,
11+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#See the License for the specific language governing permissions and
13+
#limitations under the License.
14+
import unittest
15+
import numpy as np
16+
import math
17+
from op_test import OpTest
18+
19+
20+
class TestSequenceReshape(OpTest):
    """Checks sequence_reshape when the new row width (12) is smaller than
    the input row width (24), so every sequence becomes longer."""

    def setUp(self):
        self.op_type = 'sequence_reshape'
        dimension = 12
        x_lod = [[0, 4, 5, 8, 11]]
        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'new_dim': dimension}

        out, out_lod = self.compute_output(x, x_lod, dimension)

        self.outputs = {'Out': (out, out_lod)}

    def compute_output(self, x, x_lod, dimension):
        """Reference implementation of the operator.

        Returns a tuple (out, out_lod): `out` holds the elements of `x` in the
        same order laid out in rows of `dimension` elements, and `out_lod` is
        the level-0 LoD recomputed for that row width.
        """
        x_width = x.shape[1]
        out_lod = [[0]]
        # `range` and `//` give the same result on Python 2 and Python 3.
        for i in range(len(x_lod[0]) - 1):
            seq_len = x_lod[0][i + 1] - x_lod[0][i]
            offset = (seq_len * x_width) // dimension
            # Each sequence must hold a whole number of output rows.
            assert offset * dimension == seq_len * x_width
            out_lod[0].append(out_lod[0][-1] + offset)
        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
        out.ravel()[:] = x.ravel()[:]
        return out, out_lod

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
52+
53+
class TestSequenceReshape_reduce(TestSequenceReshape):
    """Case where the new row width (24) is larger than the input row
    width (12), so every sequence becomes shorter."""

    def setUp(self):
        new_dim = 24
        lod = [[0, 4, 6, 8, 12]]
        data = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
        expected, expected_lod = self.compute_output(data, lod, new_dim)

        self.op_type = 'sequence_reshape'
        self.inputs = {'X': (data, lod)}
        self.attrs = {'new_dim': new_dim}
        self.outputs = {'Out': (expected, expected_lod)}
66+
67+
68+
class TestSequenceReshape_same(TestSequenceReshape):
    """Case where the new row width equals the input row width (12)."""

    def setUp(self):
        new_dim = 12
        lod = [[0, 4, 6, 8, 12]]
        data = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
        expected, expected_lod = self.compute_output(data, lod, new_dim)

        self.op_type = 'sequence_reshape'
        self.inputs = {'X': (data, lod)}
        self.attrs = {'new_dim': new_dim}
        self.outputs = {'Out': (expected, expected_lod)}
81+
82+
83+
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
84+
unittest.main()

0 commit comments

Comments
 (0)