Skip to content

Commit 9a18e78

Browse files
author
wanghaox
committed
update sequence slice op, fix some error
1 parent 29c2582 commit 9a18e78

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

paddle/operators/sequence_slice_op.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
7575
"the input of SequenceSliceOp.");
7676
AddInput("Offset",
7777
"(Tensor), "
78-
"A vector<int> to describes offset for sub sequence item.");
78+
"a vector<int> to describe the offset of every input sequence for "
79+
"sub sequence item.");
7980
AddInput("Length",
8081
"(Tensor), "
81-
"A vector<int> to describes length for sub sequence item.");
82+
"a vector<int> to describe the length of every input sequence for "
83+
"sub sequence item.");
8284
AddOutput("Out",
83-
"(LoDTensor), output of sequence slice Op.");
85+
"(LoDTensor), The output of SequenceSliceOp.");
8486
AddComment(R"DOC(
8587
Sequence slice operator
88+
8689
The operator crop a subsequence from given sequence with given start offset and subsequence length.
8790
It only supports sequence (LoD Tensor with level number is 1).
8891
- Case:
@@ -91,13 +94,13 @@ It only supports sequence (LoD Tensor with level number is 1).
9194
c1, c2]
9295
[d1, d2;
9396
e1, e2]]
94-
LoD(X) = {{0, 3, 5}}; Dims(X) = (4, 1, 2)
95-
Offset = (0, 1); Length = (2, 1)
97+
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
98+
Offset = [0, 1]; Length = [2, 1]
9699
97100
Out = [[a1, a2;
98101
b1, b2]
99102
[e1, e2]]
100-
LoD(Out) = {{0, 2, 3}}
103+
LoD(Out) = {{0, 2, 3}}; Dims(Out) = (3, 2)
101104
NOTE: The length of the input, offset and length should be the same. The offset start from 0.
102105
)DOC");
103106
}

paddle/operators/sequence_slice_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
8787

8888
out->mutable_data<T>(ctx.GetPlace());
8989
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
90+
auto out_dims = in->dims();
91+
out_dims[0] = out_lod[0][out_lod[0].size() - 1];
92+
out->Resize(out_dims);
9093
out->set_lod(out_lod);
91-
math::SetConstant<Place, T> set_zero;
92-
set_zero(ctx.device_context(), out, static_cast<T>(0));
9394

9495
auto in_stride = framework::stride(in->dims());
9596
auto out_stride = framework::stride(out->dims());

python/paddle/v2/framework/tests/test_sequence_slice_op.py renamed to python/paddle/v2/fluid/tests/test_sequence_slice_op.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,32 @@
55

66
class TestSequenceSliceOp(OpTest):
77
def set_data(self):
8+
self.init_test_case()
89
# only supprot one level LoD
9-
x = np.random.random((100, 3, 2)).astype('float32')
10-
lod = [[0, 20, 40, 60, 80, 100]]
11-
offset = np.array([1, 2, 3, 4, 5]).flatten().astype("int64")
12-
length = np.array([10, 8, 6, 4, 2]).flatten().astype("int64")
10+
x = np.random.random(self.x_dim).astype('float32')
11+
lod = self.x_lod
12+
offset = np.array(self.offset).flatten().astype("int64")
13+
length = np.array(self.length).flatten().astype("int64")
1314

1415
self.inputs = {'X': (x, lod), 'Offset': offset, 'Length': length}
15-
outs = np.zeros((100, 3, 2)).astype('float32')
16+
outs = [] #np.zeros((100, 3, 2)).astype('float32')
1617
out_lod = [[0]]
1718
out_lod_offset = 0
1819
for i in range(len(offset)):
1920
sub_x = x[lod[0][i] + offset[i]: lod[0]
2021
[i] + offset[i] + length[i], :]
2122
out_lod_offset = out_lod_offset + len(sub_x)
22-
outs[out_lod[0][i]: out_lod_offset, :] = sub_x
23+
outs.append(sub_x)
2324
out_lod[0].append(out_lod_offset)
24-
25+
outs = np.concatenate(outs, axis=0)
2526
self.outputs = {'Out': (outs, out_lod)}
2627

28+
def init_test_case(self):
29+
self.x_dim = (100, 3, 2)
30+
self.x_lod = [[0, 20, 40, 60, 80, 100]]
31+
self.offset = [1, 2, 3, 4, 5]
32+
self.length = [10, 8, 6, 4, 2]
33+
2734
def setUp(self):
2835
self.op_type = "sequence_slice"
2936
self.set_data()

0 commit comments

Comments
 (0)