@@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
     auto* len_t = ctx.Input<LoDTensor>("Length");
     auto* out_t = ctx.Output<LoDTensor>("Out");
 
-    const int64_t* seq_len_ptr = nullptr;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    framework::Tensor seq_len_cpu =
+        ctx.AllocateTmpTensor<T, DeviceContext>(len_t->dims(), dev_ctx);
     if (platform::is_gpu_place(ctx.GetPlace())) {
-      LoDTensor seq_len_cpu;
-      seq_len_cpu.Resize(len_t->dims());
-      seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
-      framework::TensorCopy(*len_t, platform::CPUPlace(),
-                            ctx.template device_context<DeviceContext>(),
-                            &seq_len_cpu);
+      seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
+      framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
     } else {
-      seq_len_ptr = len_t->data<int64_t>();
+      seq_len_cpu = *len_t;
     }
 
-    size_t batch_size = x_t->dims()[0];
+    const int64_t* seq_len_ptr = seq_len_cpu.data<int64_t>();
+    int64_t batch_size = len_t->dims()[0];
     std::vector<size_t> out_lod0(batch_size + 1, 0);
-    for (size_t i = 0; i < batch_size; ++i) {
-      out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
+    for (int64_t i = 0; i < batch_size; ++i) {
+      out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len_ptr[i]);
     }
 
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out_t->set_lod(out_lod);
-
     std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
     if (x_t->dims().size() == 2) {
       out_dims_vec.push_back(1);
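
Note on the hunk above: the rewritten loop builds the output's level-0 LoD as a running prefix sum of the per-sequence lengths, so out_lod0[i + 1] is the end offset of sequence i in the unpadded output. A minimal standalone sketch of that construction, using hypothetical lengths rather than Paddle tensors:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-sequence lengths, standing in for seq_len_ptr.
  const std::vector<int64_t> seq_len{3, 1, 4};
  const int64_t batch_size = static_cast<int64_t>(seq_len.size());

  // out_lod0[i + 1] = total rows of sequences 0..i (a prefix sum).
  std::vector<size_t> out_lod0(batch_size + 1, 0);
  for (int64_t i = 0; i < batch_size; ++i) {
    out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len[i]);
  }

  for (size_t off : out_lod0) std::printf("%zu ", off);  // prints: 0 3 4 8
  std::printf("\n");
  return 0;
}

Sequence i then occupies rows [out_lod0[i], out_lod0[i + 1]) of Out, and out_lod0.back() is the total row count, which is exactly where out_dims_vec starts from.
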
@@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
 
     int64_t padded_length = x_t->dims()[1];
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *x_t, out_t,
-        padded_length, 0, false, math::kBatchLengthWidth);
+        dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
   }
 };
 
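Why TensorCopySync rather than TensorCopy: on a GPU place, TensorCopy enqueues the device-to-host copy on the device context's stream and returns immediately, while the very next lines dereference seq_len_ptr on the CPU. The synchronous copy guarantees the lengths have landed in host memory before they are read. A standalone CUDA sketch of the same hazard and fix (illustrative names, not Paddle code):

// Async device-to-host copies are only *enqueued* on a stream, so the CPU
// must synchronize before it reads the destination host buffer.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kBatch = 4;
  const int64_t dev_init[kBatch] = {3, 1, 4, 2};

  int64_t* dev_len = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dev_len), kBatch * sizeof(int64_t));
  cudaMemcpy(dev_len, dev_init, sizeof(dev_init), cudaMemcpyHostToDevice);

  // Pinned host memory makes the device-to-host copy genuinely asynchronous.
  int64_t* host_len = nullptr;
  cudaMallocHost(reinterpret_cast<void**>(&host_len), kBatch * sizeof(int64_t));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Analogue of the old TensorCopy path: the copy is queued and the call
  // returns; host_len may still hold stale data if read right away.
  cudaMemcpyAsync(host_len, dev_len, kBatch * sizeof(int64_t),
                  cudaMemcpyDeviceToHost, stream);

  // Analogue of TensorCopySync: copy-then-wait. Only after this point is
  // the CPU-side read of the lengths safe.
  cudaStreamSynchronize(stream);

  for (int i = 0; i < kBatch; ++i) {
    std::printf("len[%d] = %lld\n", i, static_cast<long long>(host_len[i]));
  }

  cudaStreamDestroy(stream);
  cudaFreeHost(host_len);
  cudaFree(dev_len);
  return 0;
}

Hoisting seq_len_ptr out of the branch also lets the CPU and GPU paths share one read-side code path: both now populate a host-visible seq_len_cpu tensor and take the data pointer from it afterwards.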