
Commit 23b0388

wanghaox committed: add sub sequence operator code and unittest
1 parent ce08645 commit 23b0388

4 files changed: +320 −0 lines changed

paddle/operators/sub_sequence_op.cc

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sub_sequence_op.h"

namespace paddle {
namespace operators {

class SubSequenceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SubSequenceOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SubSequenceOp should not be null.");
    auto input_dims = ctx->GetInputDim("X");

    auto offsets = ctx->Attrs().Get<std::vector<int>>("offset");
    auto sizes = ctx->Attrs().Get<std::vector<int>>("size");

    // The first dimension of the output is the sum of all sub-sequence sizes.
    int64_t dim_0 = 0;
    for (size_t i = 0; i < sizes.size(); ++i) {
      dim_0 += sizes[i];
    }

    framework::DDim out_dims = input_dims;
    out_dims[0] = dim_0;
    ctx->SetOutputDim("Out", out_dims);
  }
};

class SubSequenceGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "The gradient of Out should not be null.");
    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
                   "The gradient of X should not be null.");
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }
};

class SubSequenceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SubSequenceOpMaker(framework::OpProto* proto,
                     framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(LoDTensor), "
             "the variable-length input of SubSequenceOp");
    AddAttr<std::vector<int>>(
        "offset",
        "A list<int> that describes the offset of each sub-sequence.");
    AddAttr<std::vector<int>>(
        "size",
        "A list<int> that describes the size of each sub-sequence.");
    AddOutput("Out",
              "(LoDTensor), the variable-length output of SubSequenceOp.");
    AddComment(R"DOC(
Sub Sequence operator

The operator crops a sub-sequence from each sequence of the input, given a
per-sequence start offset and sub-sequence size.
It only supports sequences (LoD tensors with LoD level 1).
- Case:
    LoD(x) = {{0, 3, 6, 10}}; Dims(x) = (10, 3, 2)
    offset = (0, 1, 1); size = (2, 1, 2)
    LoD(Out) = {{0, 2, 3, 5}}; Dims(Out) = (5, 3, 2)
NOTE: The lengths of offset and size must equal the number of sequences in
the input. Offsets start from 0.
    )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(sub_sequence, ops::SubSequenceOp, ops::SubSequenceOpMaker,
            sub_sequence_grad, ops::SubSequenceGradOp);
REGISTER_OP_CPU_KERNEL(
    sub_sequence,
    ops::SubSequenceOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    sub_sequence_grad,
    ops::SubSequenceGradOpKernel<paddle::platform::CPUPlace, float>);
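
To make the DOC case above concrete, here is a minimal NumPy sketch of the
forward semantics (my own illustration, not part of the commit; the names
mirror the DOC example):

    import numpy as np

    x = np.random.random((10, 3, 2)).astype('float32')
    lod = [0, 3, 6, 10]  # three sequences: x[0:3], x[3:6], x[6:10]
    offsets = [0, 1, 1]
    sizes = [2, 1, 2]

    outs, out_lod = [], [0]
    for i in range(len(lod) - 1):
        # crop sizes[i] rows starting offsets[i] into sequence i
        start = lod[i] + offsets[i]
        outs.append(x[start:start + sizes[i]])
        out_lod.append(out_lod[-1] + sizes[i])

    out = np.concatenate(outs, axis=0)
    assert out.shape == (5, 3, 2)      # Dims(Out) from the DOC case
    assert out_lod == [0, 2, 3, 5]     # LoD(Out) from the DOC case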

paddle/operators/sub_sequence_op.cu

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/sub_sequence_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    sub_sequence,
    ops::SubSequenceOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    sub_sequence_grad,
    ops::SubSequenceGradOpKernel<paddle::platform::GPUPlace, float>);

paddle/operators/sub_sequence_op.h

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

// Builds the output LoD from the per-sequence sizes: output sequence i has
// length sizes[i], so the level-0 offsets are the running sum of sizes.
template <typename T>
LoD subsequenceLoD(const T* in, const std::vector<int> offsets,
                   const std::vector<int> sizes) {
  auto out_lod = in->lod();
  size_t lod_offset = 0;

  auto n = in->lod()[0].size() - 1;
  out_lod[0][0] = 0;
  for (size_t i = 0; i < n; ++i) {
    lod_offset += sizes[i];
    out_lod[0][i + 1] = lod_offset;
  }
  return out_lod;
}

template <typename Place, typename T>
class SubSequenceOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
    auto* out = ctx.Output<LoDTensor>("Out");

    auto offset_len = offsets.size();
    auto size_len = sizes.size();

    auto lod = in->lod();
    auto n = lod[0].size() - 1;

    // Check the input data format.
    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
    PADDLE_ENFORCE_EQ(n, offset_len,
                      "The length of input and offset should be the same.");
    PADDLE_ENFORCE_EQ(n, size_len,
                      "The length of input and size should be the same.");

    for (size_t i = 0; i < n; ++i) {
      auto offset = offsets[i];
      auto size = sizes[i];
      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
                        "The target tensor's length overflows.");
    }

    out->mutable_data<T>(ctx.GetPlace());
    auto out_lod = subsequenceLoD(in, offsets, sizes);
    out->set_lod(out_lod);

    auto in_stride = framework::stride(in->dims());
    auto out_stride = framework::stride(out->dims());

    // Copy each cropped sub-sequence into the output, back to back.
    size_t out_offset = 0;
    for (size_t i = 0; i < n; ++i) {
      auto offset = offsets[i];
      auto size = sizes[i];

      Tensor in_t = in->Slice(static_cast<int>(lod[0][i] + offset),
                              static_cast<int>(lod[0][i] + offset + size));

      StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(), in_stride,
                       in_t.dims(), out_stride, out->data<T>() + out_offset);
      out_offset += size * in_stride[0];
    }
  }
};

template <typename Place, typename T>
class SubSequenceGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
    std::vector<int> offsets = ctx.Attr<std::vector<int>>("offset");
    std::vector<int> sizes = ctx.Attr<std::vector<int>>("size");
    auto* out_grad =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));

    auto offset_len = offsets.size();
    auto size_len = sizes.size();

    auto lod = in->lod();
    auto n = lod[0].size() - 1;

    // Check the input data format.
    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
    PADDLE_ENFORCE_EQ(n, offset_len,
                      "The length of input and offset should be the same.");
    PADDLE_ENFORCE_EQ(n, size_len,
                      "The length of input and size should be the same.");

    for (size_t i = 0; i < n; ++i) {
      auto offset = offsets[i];
      auto size = sizes[i];
      PADDLE_ENFORCE_LT(lod[0][i] + offset + size, lod[0][i + 1],
                        "The target tensor's length overflows.");
    }

    auto out_lod = subsequenceLoD(in, offsets, sizes);

    x_grad->set_lod(lod);
    x_grad->mutable_data<T>(ctx.GetPlace());
    // Zero-initialize the input gradient; positions outside the cropped
    // sub-sequences receive no gradient.
    auto temp = framework::EigenVector<T>::Flatten(*x_grad);
    temp.device(ctx.GetEigenDevice<Place>()) = temp.constant(static_cast<T>(0));

    // Scatter each output-gradient sequence back into its cropped slice of X.
    for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
      Tensor out_grad_t =
          out_grad->Slice(static_cast<int>(out_lod[0][i]),
                          static_cast<int>(out_lod[0][i + 1]));
      auto out_grad_stride = framework::stride(out_grad_t.dims());
      auto x_grad_stride = framework::stride(x_grad->dims());

      auto offset = offsets[i];
      auto size = sizes[i];

      Tensor x_grad_t =
          x_grad->Slice(static_cast<int>(lod[0][i] + offset),
                        static_cast<int>(lod[0][i] + offset + size));

      StridedMemcpy<T>(ctx.device_context(), out_grad_t.data<T>(),
                       out_grad_stride, out_grad_t.dims(), x_grad_stride,
                       x_grad_t.data<T>());
    }
  }
};

}  // namespace operators
}  // namespace paddle
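
The backward pass mirrors the forward copy: the output gradient is scattered
back into the cropped positions of X, and everything outside a sub-sequence
stays zero. A minimal NumPy sketch of that behavior (my own illustration, not
part of the commit), reusing the DOC case:

    import numpy as np

    x_shape, lod = (10, 3, 2), [0, 3, 6, 10]
    offsets, sizes = [0, 1, 1], [2, 1, 2]
    out_grad = np.ones((sum(sizes),) + x_shape[1:], dtype='float32')

    x_grad = np.zeros(x_shape, dtype='float32')
    out_pos = 0
    for i in range(len(lod) - 1):
        # rows [start, start + sizes[i]) of sequence i receive the gradient
        start = lod[i] + offsets[i]
        x_grad[start:start + sizes[i]] = out_grad[out_pos:out_pos + sizes[i]]
        out_pos += sizes[i]

    # all gradient mass is scattered back; uncropped rows remain zero
    assert x_grad.sum() == out_grad.sum()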
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
import unittest
import numpy as np
import sys
from op_test import OpTest


class TestSubSequenceOp(OpTest):
    def set_data(self):
        # only support one level of LoD
        x = np.random.random((100, 3, 2)).astype('float32')
        lod = [[0, 20, 40, 60, 80, 100]]
        offsets = np.array([1, 2, 3, 4, 5]).flatten()
        sizes = np.array([10, 8, 6, 4, 2]).flatten()

        self.inputs = {'X': (x, lod)}
        self.attrs = {'offset': offsets, 'size': sizes}
        outs = []
        out_lod = [[0]]
        out_lod_offset = 0
        for i in range(len(offsets)):
            # reference result: crop sizes[i] rows at offsets[i] of sequence i
            sub_x = x[lod[0][i] + offsets[i]:lod[0][i] + offsets[i] +
                      sizes[i], :]
            outs.append(sub_x)
            out_lod_offset = out_lod_offset + len(sub_x)
            out_lod[0].append(out_lod_offset)

        outs = np.concatenate(outs, axis=0)
        self.outputs = {'Out': outs}

    def setUp(self):
        self.op_type = "sub_sequence"
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


if __name__ == '__main__':
    unittest.main()
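
For reference, with the sizes above the test's expected output has shape
(30, 3, 2); a quick check of the implied output LoD (my own computation, not
part of the commit):

    sizes = [10, 8, 6, 4, 2]
    out_lod = [0]
    for s in sizes:
        out_lod.append(out_lod[-1] + s)
    print(out_lod)  # [0, 10, 18, 24, 28, 30]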
