@@ -48,42 +48,42 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
48
48
auto * length = ctx.Input <Tensor>(" Length" );
49
49
auto * out = ctx.Output <LoDTensor>(" Out" );
50
50
51
+ auto lod = in->lod ();
52
+ auto n = lod[0 ].size () - 1 ;
53
+
54
+ PADDLE_ENFORCE_EQ (lod.size (), 1UL ,
55
+ " Only support one level sequence now." );
56
+ PADDLE_ENFORCE_EQ (
57
+ n, length->dims ()[0 ],
58
+ " The size of input-sequence and length-array should be the same" )
59
+ PADDLE_ENFORCE_EQ (
60
+ n, offset->dims ()[0 ],
61
+ " The size of input-sequence and offset-array should be the same" )
62
+
51
63
const int64_t * offset_data = offset->data <int64_t >();
52
64
const int64_t * length_data = length->data <int64_t >();
65
+ framework::Tensor offset_cpu;
66
+ framework::Tensor length_cpu;
53
67
54
68
if (platform::is_gpu_place (ctx.GetPlace ())) {
55
- framework::Tensor offset_cpu;
56
69
offset_cpu.mutable_data <T>(offset->dims (), platform::CPUPlace ());
57
70
offset_cpu.CopyFrom (*offset, platform::CPUPlace (), ctx.device_context ());
58
71
offset_data = offset_cpu.data <int64_t >();
59
72
60
- framework::Tensor length_cpu;
61
73
length_cpu.mutable_data <T>(length->dims (), platform::CPUPlace ());
62
74
length_cpu.CopyFrom (*length, platform::CPUPlace (), ctx.device_context ());
63
75
length_data = length_cpu.data <int64_t >();
64
76
}
65
77
66
- auto lod = in->lod ();
67
- auto n = lod[0 ].size () - 1 ;
68
-
69
- PADDLE_ENFORCE_EQ (lod.size (), 1UL , " Only support one level sequence now." );
70
- PADDLE_ENFORCE_EQ (offset->dims ().size (), 1UL ,
71
- " Only support one level sequence now." );
72
- PADDLE_ENFORCE_EQ (length->dims ().size (), 1UL ,
73
- " Only support one level sequence now." );
74
- PADDLE_ENFORCE_EQ (
75
- n, length->dims ()[0 ],
76
- " The size of input-sequence and length-array should be the same" )
77
- PADDLE_ENFORCE_EQ (
78
- n, offset->dims ()[0 ],
79
- " The size of input-sequence and offset-array should be the same" )
80
-
81
78
for (size_t i = 0 ; i < n; ++i) {
82
- PADDLE_ENFORCE_LT (0 , offset_data[i], " The offset must greater than zero" )
83
- PADDLE_ENFORCE_LT (0 , length_data[i], " The length must greater than zero" )
84
- PADDLE_ENFORCE_LT (lod[0 ][i] + offset_data[i] + length_data[i],
85
- lod[0 ][i + 1 ], " The target tensor's length overflow" )
86
- }
79
+ PADDLE_ENFORCE_LT (0 , offset_data[i],
80
+ " The offset must greater than zero" )
81
+ PADDLE_ENFORCE_LT (0 , length_data[i],
82
+ " The length must greater than zero" )
83
+ PADDLE_ENFORCE_LT (
84
+ lod[0 ][i] + offset_data[i] + length_data[i],
85
+ lod[0 ][i + 1 ],
86
+ " The target tensor's length overflow" )}
87
87
88
88
out->mutable_data <T>(ctx.GetPlace ());
89
89
auto out_lod = SequenceSliceLoD (*in, offset_data, length_data);
@@ -100,7 +100,7 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
100
100
Tensor in_t =
101
101
in->Slice (static_cast <int >(lod[0 ][i] + offset_data[i]),
102
102
static_cast <int >(lod[0 ][i] + offset_data[i] +
103
- length_data[i]));
103
+ length_data[i]));
104
104
105
105
StridedMemcpy<T>(ctx.device_context (), in_t .data <T>(),
106
106
in_stride, in_t .dims (), out_stride,
0 commit comments