Skip to content

Commit 45bd589

Browse files
author
Yibing Liu
authored
Fix the bug of sequence_unpad op (#18290) (#18305)
* Use TensorCopySync for the sequence_unpad op
* Fix the tensor memory allocation bug
test=release/1.5
1 parent 129f271 commit 45bd589

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

paddle/fluid/operators/sequence_ops/sequence_unpad_op.h

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,26 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
3434
auto* len_t = ctx.Input<LoDTensor>("Length");
3535
auto* out_t = ctx.Output<LoDTensor>("Out");
3636

37-
const int64_t* seq_len_ptr = nullptr;
37+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
38+
framework::Tensor seq_len_cpu =
39+
ctx.AllocateTmpTensor<T, DeviceContext>(len_t->dims(), dev_ctx);
3840
if (platform::is_gpu_place(ctx.GetPlace())) {
39-
LoDTensor seq_len_cpu;
40-
seq_len_cpu.Resize(len_t->dims());
41-
seq_len_ptr = seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
42-
framework::TensorCopy(*len_t, platform::CPUPlace(),
43-
ctx.template device_context<DeviceContext>(),
44-
&seq_len_cpu);
41+
seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
42+
framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
4543
} else {
46-
seq_len_ptr = len_t->data<int64_t>();
44+
seq_len_cpu = *len_t;
4745
}
4846

49-
size_t batch_size = x_t->dims()[0];
47+
const int64_t* seq_len_ptr = seq_len_cpu.data<int64_t>();
48+
int64_t batch_size = len_t->dims()[0];
5049
std::vector<size_t> out_lod0(batch_size + 1, 0);
51-
for (size_t i = 0; i < batch_size; ++i) {
52-
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
50+
for (int64_t i = 0; i < batch_size; ++i) {
51+
out_lod0[i + 1] = out_lod0[i] + static_cast<size_t>(seq_len_ptr[i]);
5352
}
5453

5554
framework::LoD out_lod;
5655
out_lod.push_back(out_lod0);
5756
out_t->set_lod(out_lod);
58-
5957
std::vector<int64_t> out_dims_vec{static_cast<int64_t>(out_lod0.back())};
6058
if (x_t->dims().size() == 2) {
6159
out_dims_vec.push_back(1);
@@ -71,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
7169

7270
int64_t padded_length = x_t->dims()[1];
7371
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
74-
ctx.template device_context<DeviceContext>(), *x_t, out_t,
75-
padded_length, 0, false, math::kBatchLengthWidth);
72+
dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
7673
}
7774
};
7875

0 commit comments

Comments
 (0)