
Commit 08cb472

Simplify the implementation.

1 parent fc581bc commit 08cb472

4 files changed: +68 -104 lines changed

paddle/operators/sequence_reshape_op.cc

Lines changed: 18 additions & 12 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/sequence_reshape_op.h"
+#include "paddle/framework/ddim.h"
 
 namespace paddle {
 namespace operators {
@@ -26,9 +27,11 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequenceReshapeOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
+    auto x_numel = product(x_dims);
     PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
-    int dimension = ctx->Attrs().Get<int>("new_dim");
-    ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
+    int new_dim = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out",
+                      {x_numel / new_dim, static_cast<int64_t>(new_dim)});
   }
 };
 
@@ -54,16 +57,16 @@ example will help to illustrate the function of this operator:
 
 x is a LoDTensor:
     x.lod = [[0, 2, 6]]
-    x.data = [[0.1, 0.2], [0.3, 0.4],
-              [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]
+    x.data = [[1, 2], [3, 4],
+              [5, 6], [7, 8], [9, 10], [11, 12]]
    x.dims = [6, 2]
 
 set new_dim = 4
 
 then out is a LoDTensor:
-    out.lod = [[0, 1, 3]]
-    out.data = [[0.1, 0.2, 0.3, 0.4],
-                [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
+    out.lod = [[0, 1, 3]]
+    out.data = [[1, 2, 3, 4],
+                [5, 6, 7, 8], [9, 10, 11, 12]]
     out.dims = [3, 4]
 
 Currently, only 1-level LoDTensor is supported and please make sure (original
@@ -82,8 +85,6 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(
         ctx->HasInput(framework::GradVarName("Out")),
         "Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"),
-                   "Input(Out) of SequenceReshapeGradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SequenceReshapeGradOp should not be null.");
 
@@ -101,7 +102,6 @@ class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
     auto* op_desc_ptr = new framework::OpDesc();
     op_desc_ptr->SetType("sequence_reshape_grad");
     op_desc_ptr->SetInput("X", Input("X"));
-    op_desc_ptr->SetInput("Out", Output("Out"));
     op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     op_desc_ptr->SetAttrMap(Attrs());
@@ -118,7 +118,13 @@ REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
 REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape,
-    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     sequence_reshape_grad,
-    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>);
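
The substantive change above is in InferShape: Out is now sized as {x_numel / new_dim, new_dim} instead of keeping x_dims[0], so the output shape preserves the element count, matching the example in the operator's doc comment. A minimal NumPy sketch of those semantics (the sequence_reshape helper below is illustrative, not a PaddlePaddle API):

    import numpy as np

    def sequence_reshape(x, lod, new_dim):
        # Re-chunk rows to width new_dim; each sequence's length scales
        # by in_width / new_dim, so the total element count is unchanged.
        in_width = x.shape[1]
        out_lod = [0]
        for start, end in zip(lod[:-1], lod[1:]):
            seq_elems = (end - start) * in_width
            # (seq_len * in_width) must be divisible by new_dim
            assert seq_elems % new_dim == 0
            out_lod.append(out_lod[-1] + seq_elems // new_dim)
        # Row-major order is preserved, so a flat reshape suffices.
        return x.ravel().reshape(out_lod[-1], new_dim), out_lod

    x = np.arange(1, 13).reshape(6, 2)  # the doc-comment example
    out, out_lod = sequence_reshape(x, [0, 2, 6], 4)
    print(out_lod)  # [0, 1, 3]
    print(out)      # [[1 2 3 4] [5 6 7 8] [9 10 11 12]]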

paddle/operators/sequence_reshape_op.cu

Lines changed: 9 additions & 2 deletions
@@ -17,7 +17,14 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     sequence_reshape,
-    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     sequence_reshape_grad,
-    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   int64_t>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>);
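
With these registrations the CUDA (and CPU) kernels now cover float, double, int, and int64_t rather than float alone, so integer sequence data such as token ids can be reshaped directly. The semantics are dtype-agnostic; using the hypothetical NumPy sketch above:

    ids = np.arange(8, dtype=np.int64).reshape(4, 2)  # e.g. token ids
    out, out_lod = sequence_reshape(ids, [0, 2, 4], 4)
    assert out.dtype == np.int64 and out.shape == (2, 4)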

paddle/operators/sequence_reshape_op.h

Lines changed: 25 additions & 82 deletions
@@ -28,8 +28,6 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
     auto* out = context.Output<LoDTensor>("Out");
     int out_width = context.Attr<int>("new_dim");
 
-    const T* p_in_data = in->data<T>();
-
     auto in_dims = in->dims();
     int64_t in_width = in_dims[1];
     auto& in_lod = in->lod();
@@ -43,53 +41,29 @@ class SequenceReshapeKernel : public framework::OpKernel<T> {
     auto in_lod_l0 = in_lod[0];
     int seq_num = in_lod_l0.size() - 1;
 
-    auto& out_lod = *out->mutable_lod();
-    out_lod.resize(1);
-    out_lod[0].clear();
-    out_lod[0].push_back(0);
-    for (int i = 0; i < seq_num; ++i) {
-      size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
-      size_t offset = 0;
-      offset = (seq_len * in_width) / out_width;
-      PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
-                        "Please make sure (sequence_length * dimension) can be "
-                        "divided by new_dim with no remainder for each "
-                        "sequence. The %dth sequence is invalid.",
-                        i + 1);
-      PADDLE_ENFORCE_GT(offset, 0,
-                        "Illegal operation, length of the %dth sequence become "
-                        "to 0 after reshaped.",
-                        i + 1);
-      out_lod[0].push_back(out_lod[0].back() + offset);
+    if (in_width == out_width) {
+      out->set_lod(in->lod());
+    } else {
+      auto& out_lod = *out->mutable_lod();
+      out_lod.resize(1);
+      out_lod[0].clear();
+      out_lod[0].push_back(0);
+      for (int i = 0; i < seq_num; ++i) {
+        size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
+        size_t offset = 0;
+        offset = (seq_len * in_width) / out_width;
+        PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                          "Please make sure (sequence_length * dimension) can "
+                          "be divided by new_dim with no remainder for each "
+                          "sequence. The %dth sequence is invalid.",
+                          i + 1);
+        out_lod[0].push_back(out_lod[0].back() + offset);
+      }
     }
 
     out->mutable_data<T>(context.GetPlace());
-    out->Resize({static_cast<int64_t>(out_lod[0].back()), out_width});
-    T* p_out_data = out->mutable_data<T>(context.GetPlace());
-    math::set_constant(context.device_context(), out, 0.0f);
-
-    for (int i = 0; i < seq_num; ++i) {
-      size_t in_offset = in_lod_l0[i] * in_width;
-      size_t out_offset = out_lod[0][i] * out_width;
-      size_t in_count = (in_lod_l0[i + 1] - in_lod_l0[i]) * in_width;
-      size_t out_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
-      size_t bytes = sizeof(T) * std::min(in_count, out_count);
-      if (platform::is_cpu_place(context.GetPlace())) {
-        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_out_data + out_offset,
-                     boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_in_data + in_offset, bytes);
-      } else {
-#ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx =
-            context.template device_context<platform::CUDADeviceContext>();
-        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_out_data + out_offset,
-                     boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_in_data + in_offset, bytes, dev_ctx.stream());
-#endif
-      }
-    }
+    framework::Copy(*in, context.GetPlace(), out);
+    out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
   }
 };
 
@@ -98,45 +72,14 @@ class SequenceReshapeGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x_tensor_ptr = context.Input<LoDTensor>("X");
-    auto* out_tensor_ptr = context.Input<LoDTensor>("Out");
-    auto* out_grad_tensor_ptr =
+    auto* outg_tensor_ptr =
         context.Input<LoDTensor>(framework::GradVarName("Out"));
-    auto* x_grad_tensor_ptr =
+    auto* xg_tensor_ptr =
        context.Output<LoDTensor>(framework::GradVarName("X"));
 
-    T* p_x_grad_data = x_grad_tensor_ptr->mutable_data<T>(context.GetPlace());
-    const T* p_out_grad_data = out_grad_tensor_ptr->data<T>();
-
-    auto& x_lod = x_tensor_ptr->lod();
-    int seq_num = x_lod[0].size() - 1;
-    int x_width = x_tensor_ptr->dims()[1];
-    auto& out_lod = out_tensor_ptr->lod();
-    int out_width = out_tensor_ptr->dims()[1];
-
-    math::set_constant(context.device_context(), x_grad_tensor_ptr, 0.0f);
-
-    for (int i = 0; i < seq_num; ++i) {
-      size_t src_offset = out_lod[0][i] * out_width;
-      size_t dst_offset = x_lod[0][i] * x_width;
-      size_t src_count = (out_lod[0][i + 1] - out_lod[0][i]) * out_width;
-      size_t dst_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width;
-      size_t bytes = sizeof(T) * std::min(src_count, dst_count);
-      if (platform::is_cpu_place(context.GetPlace())) {
-        memory::Copy(boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_x_grad_data + dst_offset,
-                     boost::get<platform::CPUPlace>(context.GetPlace()),
-                     p_out_grad_data + src_offset, bytes);
-      } else {
-#ifdef PADDLE_WITH_CUDA
-        auto& dev_ctx =
-            context.template device_context<platform::CUDADeviceContext>();
-        memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_x_grad_data + dst_offset,
-                     boost::get<platform::CUDAPlace>(context.GetPlace()),
-                     p_out_grad_data + src_offset, bytes, dev_ctx.stream());
-#endif
-      }
-    }
+    xg_tensor_ptr->mutable_data<T>(context.GetPlace());
+    framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
+    xg_tensor_ptr->Resize(x_tensor_ptr->dims());
   }
 };
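
The whole simplification rests on one observation: the input and output store the same elements in the same row-major order, so the old per-sequence memory::Copy loop (with its CPU/CUDA branches) is equivalent to one flat copy of the tensor, which framework::Copy provides, followed by a Resize to the new shape. The same argument covers the gradient kernel: Out@GRAD holds exactly the elements of X@GRAD in order, so a flat copy plus a Resize back to X's dims recovers the gradient. A NumPy sketch of the equivalence, with illustrative names:

    import numpy as np

    x = np.random.rand(6, 2).astype('float32')
    x_lod, out_lod = [0, 2, 6], [0, 1, 3]
    in_width, new_dim = 2, 4

    flat = x.ravel().reshape(3, new_dim)  # new code path: one flat copy

    # old code path: one copy per sequence
    per_seq = np.zeros((3, new_dim), dtype='float32')
    for i in range(len(x_lod) - 1):
        src = x.ravel()[x_lod[i] * in_width:x_lod[i + 1] * in_width]
        per_seq.ravel()[out_lod[i] * new_dim:out_lod[i + 1] * new_dim] = src

    assert np.array_equal(flat, per_seq)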

python/paddle/v2/fluid/tests/test_sequence_reshape.py

Lines changed: 16 additions & 8 deletions
@@ -40,14 +40,7 @@ def compute_output(self, x, x_lod, dimension):
             assert int(offset) * dimension == seq_len * x_width
             out_lod[0].append(out_lod[0][-1] + int(offset))
         out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
-        for i in xrange(len(x_lod[0]) - 1):
-            x_offset = x_lod[0][i] * x_width
-            out_offset = out_lod[0][i] * dimension
-            out_count = (out_lod[0][i + 1] - out_lod[0][i]) * dimension
-            x_count = (x_lod[0][i + 1] - x_lod[0][i]) * x_width
-            count = min(out_count, x_count)
-            out.ravel()[out_offset:out_offset + count] = x.ravel()[
-                x_offset:x_offset + count]
+        out.ravel()[:] = x.ravel()[:]
         return out, out_lod
 
     def test_check_output(self):
@@ -72,5 +65,20 @@ def setUp(self):
         self.outputs = {'Out': (out, out_lod)}
 
 
+class TestSequenceReshape_same(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
 if __name__ == '__main__':
     unittest.main()
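
The rewritten compute_output mirrors the kernel: a single flat copy instead of the per-sequence loop. The new TestSequenceReshape_same case sets new_dim equal to the input width (both 12), exercising the kernel's in_width == out_width branch where the LoD is passed through unchanged; the reference computation then reduces to an identity. With the sketch above (illustrative, not the test's actual helper):

    x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
    out, out_lod = sequence_reshape(x, [0, 4, 6, 8, 12], 12)
    assert out_lod == [0, 4, 6, 8, 12] and np.array_equal(out, x)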
