
Commit d50f518
Parent: 9db40b8

bug fix

2 files changed: 26 additions & 16 deletions


paddle/fluid/operators/math/sequence_padding.cu

Lines changed: 24 additions & 14 deletions
@@ -22,8 +22,8 @@ namespace math {
 template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
     T* dst, const T* src, const T* pad_value, bool is_constant_pad,
-    const size_t* seq_offsets, const size_t& seq_num, const size_t& pad_seq_len,
-    const size_t& step_width, bool norm_by_len, const PadLayout& layout) {
+    const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
+    const size_t step_width, bool norm_by_len, const PadLayout layout) {
   size_t seq_idx = blockIdx.y;
   size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
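Note: the signature change above is the heart of the fix. Arguments to a __global__ kernel are copied into device parameter space at launch, so a const size_t& parameter smuggles a host address onto the GPU, and reading through it there is undefined behavior. A minimal standalone sketch of the failure mode (hypothetical kernel names, not part of this patch):

// Broken: `n` binds to a host-side object; the device copy of the parameter
// holds a host address, and dereferencing it on the GPU is undefined
// behavior (typically an invalid-address fault or garbage reads).
__global__ void bad_kernel(const size_t& n, int* out) {
  if (threadIdx.x < n) out[threadIdx.x] = 1;
}

// Fixed: scalars are passed by value, so the launch copies the actual
// number into the kernel's parameter space.
__global__ void good_kernel(const size_t n, int* out) {
  if (threadIdx.x < n) out[threadIdx.x] = 1;
}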

@@ -43,7 +43,7 @@ __global__ void SequencePaddingKernel(
       dst_data[i] = scale * src_data[i];
     }
   } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
-    for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
+    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
       dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
     }
   }
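Note: `seq_width` is not a parameter of this kernel; the loop bound has to be `step_width`, the number of elements written per time step. The branch fills the padding region either with a single scalar (`pad_value[0]`) or element-wise from a `step_width`-long pad row. A host-side analogue of the intended fill, for illustration only:

#include <cstddef>

// For time steps past the real sequence length, write pad values across the
// step_width elements of the step: one shared scalar, or one value per slot.
void fill_pad_step(float* dst_data, const float* pad_value,
                   bool is_constant_pad, size_t step_width) {
  for (size_t i = 0; i < step_width; ++i) {
    dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
  }
}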
@@ -54,33 +54,34 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* pad_tensor,
+                  framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
     const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
     const auto& seq_tensor_dims = seq_tensor.dims();
     const auto& pad_tensor_dims = pad_tensor->dims();
+    int max_seq_len = MaximumSequenceLength(seq_offsets);
     if (pad_seq_len == -1) {
-      pad_seq_len = MaximumSequenceLength(seq_offsets);
+      pad_seq_len = max_seq_len;
     }
     int step_width = seq_tensor.numel() / seq_tensor_dims[0];
-    int seq_num = seq_offset.size() - 1;
+    int seq_num = seq_offsets.size() - 1;
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
     PADDLE_ENFORCE(pad_value.numel() == 1 || pad_value.numel() == step_width,
                    "The numel of 'pad_value' can only be 1 or be equal to the "
                    "'step_width'.");
 
-    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
       TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
       pad_tensor->Resize(pad_tensor_dims);
       return;
     }
 
-    const int64_t kBlockSize = 512;
+    const int kBlockSize = 512;
 
     /* At least use 32 threads to copy sequence_width elements,
      * and at least 8 elements for each thread.
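Note: the old fast-path test `pad_seq_len == -1` could never be true at that point, because the -1 sentinel had already been rewritten to the maximum sequence length a few lines earlier; the TensorCopy shortcut was dead code. Hoisting `max_seq_len` lets the check compare against the value the sentinel resolves to. The pattern in isolation (hypothetical helpers, not Paddle API):

// Resolve the -1 ("pad to the longest sequence") sentinel exactly once...
int ResolvePadLen(int pad_seq_len, int max_seq_len) {
  return pad_seq_len == -1 ? max_seq_len : pad_seq_len;
}

// ...then phrase later tests in terms of the resolved value. Comparing with
// -1 after resolution is always false, which is what disabled the fast copy.
bool CanFastCopy(bool norm_by_times, int seq_num, int pad_seq_len,
                 int max_seq_len) {
  return !norm_by_times && seq_num == 1 && pad_seq_len == max_seq_len;
}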
@@ -100,8 +101,16 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 
     SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
         pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
         step_width, norm_by_times, layout);
+
+    if (layout == kBatchLengthWidth) {
+      framework::LoD pad_lod(seq_lod.begin() + lod_level, seq_lod.end());
+      for (size_t i = 0; i < pad_lod[0].size(); ++i) {
+        pad_lod[0][i] = i * pad_seq_len;
+      }
+      pad_tensor->set_lod(pad_lod);
+    }
   }
 };
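Note: the added block gives the padded output a valid LoD for the kBatchLengthWidth layout (which is why `pad_tensor` became a LoDTensor). After padding, every sequence occupies exactly pad_seq_len rows, so the absolute offsets become uniform multiples of pad_seq_len; e.g. three sequences padded to length 4 yield {0, 4, 8, 12}. A standalone sketch of the computation, with std::vector standing in for framework::LoD:

#include <cstddef>
#include <vector>

// Sequence i of the padded tensor starts at row i * pad_seq_len, so the
// offsets of the top LoD level are just uniform multiples of pad_seq_len.
std::vector<size_t> PaddedOffsets(size_t seq_num, size_t pad_seq_len) {
  std::vector<size_t> offsets(seq_num + 1);
  for (size_t i = 0; i <= seq_num; ++i) {
    offsets[i] = i * pad_seq_len;
  }
  return offsets;
}
// PaddedOffsets(3, 4) == {0, 4, 8, 12}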

@@ -116,22 +125,23 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
     const auto& pad_tensor_dims = pad_tensor.dims();
+    int max_seq_len = MaximumSequenceLength(seq_offsets);
     if (pad_seq_len == -1) {
-      pad_seq_len = MaximumSequenceLength(seq_offsets);
+      pad_seq_len = max_seq_len;
     }
     int step_width = seq_tensor->numel() / seq_tensor_dims[0];
-    int seq_num = seq_offset.size() - 1;
+    int seq_num = seq_offsets.size() - 1;
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
 
-    if (!norm_by_times && seq_num == 1UL && pad_seq_len == -1) {
+    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
       TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
       seq_tensor->Resize(seq_tensor_dims);
       return;
     }
 
-    const int64_t kBlockSize = 512;
+    const int kBlockSize = 512;
 
     /* At least use 32 threads to copy sequence_width elements,
      * and at least 8 elements for each thread.
@@ -150,7 +160,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 
     SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
         seq_data, pad_data, nullptr, false,
-        seq_offset.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
+        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
         step_width, norm_by_times, layout);
   }
 };
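Note: the unpadding hunks mirror the padding-side fixes (same dead fast path, same `seq_offset` typo). With the header fix below, the two functors are intended to round-trip: padding a LoDTensor into a dense batch and unpadding recovers the original rows. A hedged usage sketch based only on the signatures visible in this commit; tensor allocation is elided and assumed done by the caller:

// Sketch only: seq, pad_value, pad, out are framework::LoDTensor objects the
// caller has already allocated and shaped; ctx is a CUDADeviceContext.
void PadThenUnpad(const platform::CUDADeviceContext& ctx,
                  const framework::LoDTensor& seq,
                  const framework::LoDTensor& pad_value,
                  framework::LoDTensor* pad, framework::LoDTensor* out) {
  math::PaddingLoDTensorFunctor<platform::CUDADeviceContext, float> pad_fn;
  pad_fn(ctx, seq, pad, pad_value, /*pad_seq_len=*/-1, /*lod_level=*/0,
         /*norm_by_times=*/false, math::kBatchLengthWidth);

  math::UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float> unpad_fn;
  unpad_fn(ctx, *pad, out, /*pad_seq_len=*/-1, /*lod_level=*/0,
           /*norm_by_times=*/false, math::kBatchLengthWidth);
}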

paddle/fluid/operators/math/sequence_padding.h

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ inline static void CheckDims(const framework::DDim& seq_tensor_dims,
 template <typename DeviceContext, typename T>
 class PaddingLoDTensorFunctor {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
                   framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
@@ -92,7 +92,7 @@
 template <typename DeviceContext, typename T>
 class UnpaddingLoDTensorFunctor {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
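Note: the header change fixes the primary templates, whose operator() hard-coded platform::CPUDeviceContext even though the class is parameterized on DeviceContext; the CUDA specialization therefore implemented an interface the primary template never declared. The pattern in miniature, with illustrative stand-in types (not the real Paddle classes):

// Illustrative stand-ins for the real context types.
struct CPUDeviceContext {};
struct CUDADeviceContext {};

// Primary template: the context parameter must be the generic DeviceContext;
// hard-coding CPUDeviceContext here declares an interface that a CUDA
// specialization never actually provides.
template <typename DeviceContext, typename T>
class PaddingFunctor {
 public:
  void operator()(const DeviceContext& context);
};

// Each device specialization then matches the declared call signature.
template <typename T>
class PaddingFunctor<CPUDeviceContext, T> {
 public:
  void operator()(const CPUDeviceContext& context) { /* CPU path */ }
};

template <typename T>
class PaddingFunctor<CUDADeviceContext, T> {
 public:
  void operator()(const CUDADeviceContext& context) { /* CUDA path */ }
};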
