Skip to content

Commit 1d95173

Browse files
author
wanghaox
committed
change offset and length's rank to 2, dim[0] for batch size
2 parents 40a6c48 + a76b614 commit 1d95173

File tree

5 files changed

+125
-113
lines changed

5 files changed

+125
-113
lines changed

paddle/operators/sequence_slice_op.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
3232
"Output(Out) of SequenceSliceOp should not be null.");
3333
auto input_dims = ctx->GetInputDim("X");
3434

35+
auto offset_dim = ctx->GetInputDim("Offset");
36+
auto length_dim = ctx->GetInputDim("Length");
37+
38+
PADDLE_ENFORCE_EQ(offset_dim.size(), 2UL,
39+
"Only support one level sequence now.");
40+
PADDLE_ENFORCE_EQ(length_dim.size(), 2UL,
41+
"Only support one level sequence now.");
42+
3543
ctx->SetOutputDim("Out", input_dims);
3644
}
3745

@@ -95,7 +103,7 @@ It only supports sequence (LoD Tensor with level number is 1).
95103
[d1, d2;
96104
e1, e2]]
97105
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
98-
Offset = [0, 1]; Length = [2, 1]
106+
Offset = [[0], [1]]; Length = [[2], [1]]
99107
100108
Out = [[a1, a2;
101109
b1, b2]

paddle/operators/sequence_slice_op.h

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,42 +48,42 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
4848
auto* length = ctx.Input<Tensor>("Length");
4949
auto* out = ctx.Output<LoDTensor>("Out");
5050

51+
auto lod = in->lod();
52+
auto n = lod[0].size() - 1;
53+
54+
PADDLE_ENFORCE_EQ(lod.size(), 1UL,
55+
"Only support one level sequence now.");
56+
PADDLE_ENFORCE_EQ(
57+
n, length->dims()[0],
58+
"The size of input-sequence and length-array should be the same")
59+
PADDLE_ENFORCE_EQ(
60+
n, offset->dims()[0],
61+
"The size of input-sequence and offset-array should be the same")
62+
5163
const int64_t* offset_data = offset->data<int64_t>();
5264
const int64_t* length_data = length->data<int64_t>();
65+
framework::Tensor offset_cpu;
66+
framework::Tensor length_cpu;
5367

5468
if (platform::is_gpu_place(ctx.GetPlace())) {
55-
framework::Tensor offset_cpu;
5669
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
5770
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
5871
offset_data = offset_cpu.data<int64_t>();
5972

60-
framework::Tensor length_cpu;
6173
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
6274
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
6375
length_data = length_cpu.data<int64_t>();
6476
}
6577

66-
auto lod = in->lod();
67-
auto n = lod[0].size() - 1;
68-
69-
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
70-
PADDLE_ENFORCE_EQ(offset->dims().size(), 1UL,
71-
"Only support one level sequence now.");
72-
PADDLE_ENFORCE_EQ(length->dims().size(), 1UL,
73-
"Only support one level sequence now.");
74-
PADDLE_ENFORCE_EQ(
75-
n, length->dims()[0],
76-
"The size of input-sequence and length-array should be the same")
77-
PADDLE_ENFORCE_EQ(
78-
n, offset->dims()[0],
79-
"The size of input-sequence and offset-array should be the same")
80-
8178
for (size_t i = 0; i < n; ++i) {
82-
PADDLE_ENFORCE_LT(0, offset_data[i], "The offset must greater than zero")
83-
PADDLE_ENFORCE_LT(0, length_data[i], "The length must greater than zero")
84-
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
85-
lod[0][i + 1], "The target tensor's length overflow")
86-
}
79+
PADDLE_ENFORCE_LT(0, offset_data[i],
80+
"The offset must greater than zero")
81+
PADDLE_ENFORCE_LT(0, length_data[i],
82+
"The length must greater than zero")
83+
PADDLE_ENFORCE_LT(
84+
lod[0][i] + offset_data[i] + length_data[i],
85+
lod[0][i + 1],
86+
"The target tensor's length overflow")}
8787

8888
out->mutable_data<T>(ctx.GetPlace());
8989
auto out_lod = SequenceSliceLoD(*in, offset_data, length_data);
@@ -100,7 +100,7 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
100100
Tensor in_t =
101101
in->Slice(static_cast<int>(lod[0][i] + offset_data[i]),
102102
static_cast<int>(lod[0][i] + offset_data[i] +
103-
length_data[i]));
103+
length_data[i]));
104104

105105
StridedMemcpy<T>(ctx.device_context(), in_t.data<T>(),
106106
in_stride, in_t.dims(), out_stride,

0 commit comments

Comments
 (0)