
Commit 8b9938a

pkuyym authored and JiayiFeng committed

Refine code.

1 parent 5d90141 commit 8b9938a

File tree

6 files changed: +183 -169 lines changed


paddle/fluid/operators/math/sequence_padding.cc

Lines changed: 63 additions & 60 deletions
@@ -18,111 +18,114 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T, PaddingLayout padding_layout>
+template <typename T>
 void CopyDataCPU(framework::LoDTensor* seq_tensor,
-                 framework::Tensor* padding_tensor,
-                 const framework::Vector<size_t>& abs_offset,
+                 framework::Tensor* pad_tensor,
+                 const framework::Vector<size_t>& seq_offset,
                  const int64_t& max_seq_len, const int64_t& seq_width,
-                 bool seq_to_padding, bool norm_by_len) {
+                 bool seq_to_pad, bool norm_by_len,
+                 OutputLayout output_layout) {
   T* seq_data = seq_tensor->data<T>();
-  T* padding_data = padding_tensor->data<T>();
+  T* pad_data = pad_tensor->data<T>();
 
-  int64_t seq_num = abs_offset.size() - 1;
+  int64_t seq_num = seq_offset.size() - 1;
 
   for (int64_t i = 0; i < seq_num; ++i) {
-    int64_t seq_start = abs_offset[i];
-    int64_t seq_len = abs_offset[i + 1] - seq_start;
-
+    int64_t seq_start = seq_offset[i];
+    int64_t seq_len = seq_offset[i + 1] - seq_start;
     T scale = norm_by_len ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
-
     for (int64_t j = 0; j < seq_len; ++j) {
       for (int64_t k = 0; k < seq_width; ++k) {
-        size_t padding_offset = 0;
-        if (padding_layout == BATCH_LENGTH_WIDTH) {
-          padding_offset = (i * max_seq_len * seq_width) + j * seq_width + k;
+        size_t pad_data_idx = 0;
+        size_t seq_data_idx = (seq_start + j) * seq_width + k;
+        if (output_layout == kBatchLengthWidth) {
+          pad_data_idx = (i * max_seq_len + j) * seq_width + k;
         } else {
-          padding_offset = (j * seq_num * seq_width) + i * seq_width + k;
+          pad_data_idx = (j * seq_num + i) * seq_width + k;
         }
-        if (seq_to_padding) {
-          padding_data[padding_offset] =
-              seq_data[(seq_start + j) * seq_width + k] * scale;
+        if (seq_to_pad) {
+          pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale;
         } else {
-          seq_data[(seq_start + j) * seq_width + k] =
-              padding_data[padding_offset] * scale;
+          seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale;
         }
       }
     }
   }
 }
 
-template <typename T, PaddingLayout padding_layout>
-class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T, padding_layout> {
+template <typename T>
+class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* padding_tensor,
-                  T padding_value = static_cast<T>(0),
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(seq_tensor, lod_level);
+                  framework::Tensor* pad_tensor,
+                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
+                  size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(seq_tensor, lod_level);
 
     auto& lod = seq_tensor.lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor.dims();
-    auto padding_dims = padding_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor.numel() / seq_dims[0];
-    int64_t numel = max_seq_len * seq_num * seq_width;
+    auto seq_tensor_dims = seq_tensor.dims();
+    auto pad_tensor_dims = pad_tensor->dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
 
-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
-    T* padding_data = padding_tensor->data<T>();
+    T* pad_data = pad_tensor->data<T>();
 
-    memset(padding_data, padding_value, numel * sizeof(T));
+    memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T));
 
-    CopyDataCPU<T, padding_layout>(
-        const_cast<framework::LoDTensor*>(&seq_tensor), padding_tensor,
-        abs_offset, max_seq_len, seq_width, true /* seq_to_padding */,
-        norm_by_times);
+    CopyDataCPU<T>(const_cast<framework::LoDTensor*>(&seq_tensor), pad_tensor,
+                   seq_offset, max_seq_len, seq_width, true /* seq_to_pad */,
+                   norm_by_times, output_layout);
   }
 };
 
-template <typename T, PaddingLayout padding_layout>
-class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T, padding_layout> {
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& padding_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(*seq_tensor, lod_level);
+                  const framework::Tensor& pad_tensor,
+                  bool norm_by_times = false, size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(*seq_tensor, lod_level);
 
     auto& lod = seq_tensor->lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto& seq_dims = seq_tensor->dims();
-    auto& padding_dims = padding_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_dims[0];
+    auto& seq_tensor_dims = seq_tensor->dims();
+    auto& pad_tensor_dims = pad_tensor.dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
 
-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     T* seq_data = seq_tensor->data<T>();
     memset(seq_data, static_cast<T>(0), seq_tensor->numel() * sizeof(T));
 
-    CopyDataCPU<T, padding_layout>(
-        seq_tensor, const_cast<framework::Tensor*>(&padding_tensor), abs_offset,
-        max_seq_len, seq_width, false /* seq_to_padding */, norm_by_times);
+    CopyDataCPU<T>(seq_tensor, const_cast<framework::Tensor*>(&pad_tensor),
+                   seq_offset, max_seq_len, seq_width, false /* seq_to_pad */,
+                   norm_by_times, output_layout);
   }
 };
 
-template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float,
-                                       LENGTH_BATCH_WIDTH>;
-template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float,
-                                         LENGTH_BATCH_WIDTH>;
+template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int>;
+template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
+template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
+template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, double>;
+
+template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, int>;
+template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
+template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
+template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
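
The index arithmetic in CopyDataCPU is the core of this refactor: kBatchLengthWidth lays the pad tensor out as [seq_num, max_seq_len, seq_width], while the other branch (kLengthBatchWidth) lays it out as [max_seq_len, seq_num, seq_width]. Below is a minimal standalone sketch of that arithmetic, using made-up toy sizes and data; it is not part of the commit:

// Toy demonstration of the two pad-tensor layouts used by CopyDataCPU above.
// Two sequences (lengths 2 and 3), feature width 2, so max_seq_len = 3.
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> seq_offset = {0, 2, 5};  // absolute LoD offsets
  const size_t seq_num = seq_offset.size() - 1;
  const size_t max_seq_len = 3, seq_width = 2;
  std::vector<float> seq_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

  // kBatchLengthWidth -> [seq_num, max_seq_len, seq_width]
  // kLengthBatchWidth -> [max_seq_len, seq_num, seq_width]
  std::vector<float> pad_blw(seq_num * max_seq_len * seq_width, 0.f);
  std::vector<float> pad_lbw(max_seq_len * seq_num * seq_width, 0.f);

  for (size_t i = 0; i < seq_num; ++i) {
    size_t seq_start = seq_offset[i];
    size_t seq_len = seq_offset[i + 1] - seq_start;
    for (size_t j = 0; j < seq_len; ++j) {
      for (size_t k = 0; k < seq_width; ++k) {
        float v = seq_data[(seq_start + j) * seq_width + k];
        pad_blw[(i * max_seq_len + j) * seq_width + k] = v;  // batch-major
        pad_lbw[(j * seq_num + i) * seq_width + k] = v;      // length-major
      }
    }
  }

  // Step j = 0 of sequence i = 1 sits at different flat offsets per layout.
  std::printf("batch-major:  pad[%zu] = %g\n",
              (1 * max_seq_len + 0) * seq_width,
              pad_blw[(1 * max_seq_len + 0) * seq_width]);
  std::printf("length-major: pad[%zu] = %g\n",
              (0 * seq_num + 1) * seq_width,
              pad_lbw[(0 * seq_num + 1) * seq_width]);
  return 0;
}

Compiled with any C++11 compiler this prints pad[6] = 5 for the batch-major layout and pad[2] = 5 for the length-major one: the same element lands at different flat offsets, which is exactly the branch CopyDataCPU selects via output_layout.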

paddle/fluid/operators/math/sequence_padding.cu

Lines changed: 72 additions & 64 deletions
@@ -21,74 +21,74 @@ namespace math {
 
 template <typename T, bool Padding>
 __global__ void SequencePaddingKernel(
-    T* padding_data, T* seq_data, const size_t* abs_offset,
-    const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width,
-    const PaddingLayout& padding_layout, bool norm_by_times = false,
-    const T& padding_value = 0) {
-  size_t padding_idx = blockIdx.y;
-  size_t seq_start = abs_offset[padding_idx];
-  size_t seq_len = abs_offset[padding_idx + 1] - seq_start;
+    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num,
+    const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times,
+    const T& pad_value, const OutputLayout& output_layout) {
+  size_t seq_idx = blockIdx.y;
+  size_t seq_start = seq_offset[seq_idx];
+  size_t seq_len = seq_offset[seq_idx + 1] - seq_start;
 
-  size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y;
 
-  size_t seq_offset = (seq_start + seq_idx) * seq_width;
+  size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
 
-  size_t padding_offset = 0;
+  size_t pad_data_offset = 0;
 
-  if (padding_layout == LENGTH_BATCH_WIDTH) {
-    padding_offset = (seq_idx * seq_num + padding_idx) * seq_width;
+  if (output_layout == kLengthBatchWidth) {
+    pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width;
   } else {
-    padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width;
+    pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width;
   }
 
-  if (seq_idx < seq_len) {
+  if (seq_step_idx < seq_len) {
     T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = scale * seq_data[seq_offset + i];
+        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
       }
     } else {
-      /* padding -> sequence */
+      /* pad -> seq */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        seq_data[seq_offset + i] = scale * padding_data[padding_offset + i];
+        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
       }
     }
-  } else if (seq_idx < max_seq_len) {
+  } else if (seq_step_idx < max_seq_len) {
     if (Padding) {
-      /* sequence -> padding */
+      /* seq -> pad */
       for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
-        padding_data[padding_offset + i] = padding_value;
+        pad_data[pad_data_offset + i] = pad_value;
       }
     }
   }
 }
 
-template <typename T, PaddingLayout padding_layout>
-class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
+template <typename T>
+class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::LoDTensor& seq_tensor,
-                  framework::Tensor* padding_tensor,
-                  T padding_value = static_cast<T>(0),
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(seq_tensor, lod_level);
+                  framework::Tensor* pad_tensor,
+                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
+                  size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(seq_tensor, lod_level);
 
     auto& lod = seq_tensor.lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor.dims();
-    auto padding_dims = padding_tensor->dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    const int64_t seq_num = abs_offset.size() - 1;
-    const int64_t seq_width = seq_tensor.numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor.dims();
+    auto pad_tensor_dims = pad_tensor->dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];
 
-    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
-                  seq_num, seq_width, padding_layout);
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor);
-      padding_tensor->Resize(padding_dims);
+      TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
+      pad_tensor->Resize(pad_tensor_dims);
       return;
     }
 
@@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq_tensor.data<T>();
-    T* padding_data = padding_tensor->data<T>();
+    T* pad_data = pad_tensor->data<T>();
 
     SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        padding_data, const_cast<T*>(seq_data),
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times, padding_value);
+        pad_data, const_cast<T*>(seq_data),
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, pad_value, output_layout);
   }
 };
 
-template <typename T, PaddingLayout padding_layout>
-class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
-                                padding_layout> {
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   framework::LoDTensor* seq_tensor,
-                  const framework::Tensor& padding_tensor,
-                  bool norm_by_times = false, size_t lod_level = 0) {
-    ValidateLoD(*seq_tensor, lod_level);
+                  const framework::Tensor& pad_tensor,
+                  bool norm_by_times = false, size_t lod_level = 0,
+                  OutputLayout output_layout = kBatchLengthWidth) {
+    CheckLoD(*seq_tensor, lod_level);
 
     auto& lod = seq_tensor->lod();
-    auto& abs_offset = framework::ToAbsOffset(lod)[lod_level];
+    auto& seq_offset = framework::ToAbsOffset(lod)[lod_level];
 
-    auto seq_dims = seq_tensor->dims();
-    auto padding_dims = padding_tensor.dims();
-    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
-    int64_t seq_num = abs_offset.size() - 1;
-    int64_t seq_width = seq_tensor->numel() / seq_dims[0];
+    auto seq_tensor_dims = seq_tensor->dims();
+    auto pad_tensor_dims = pad_tensor.dims();
+    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
+    int64_t seq_num = seq_offset.size() - 1;
+    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
+              seq_num, seq_width, output_layout);
 
     if (!norm_by_times && seq_num == 1UL) {
-      TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor);
-      seq_tensor->Resize(seq_dims);
+      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
+      seq_tensor->Resize(seq_tensor_dims);
       return;
     }
 
@@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
     size_t grid_dim_y = seq_num;
     dim3 grid(grid_dim_x, grid_dim_y);
 
-    const T* padding_data = padding_tensor.data<T>();
+    const T* pad_data = pad_tensor.data<T>();
     T* seq_data = seq_tensor->data<T>();
 
-    SequencePaddingKernel<T, 1><<<grid, threads, 0, context.stream()>>>(
-        const_cast<T*>(padding_data), seq_data,
-        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
-        seq_width, padding_layout, norm_by_times);
+    SequencePaddingKernel<T, 0><<<grid, threads, 0, context.stream()>>>(
+        const_cast<T*>(pad_data), seq_data,
+        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
+        seq_width, norm_by_times, static_cast<T>(0), output_layout);
   }
 };
 
-template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                       LENGTH_BATCH_WIDTH>;
-template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float,
-                                         LENGTH_BATCH_WIDTH>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
+
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
+template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
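
For reference, the launch geometry above assigns one sequence per grid-y slot and tiles time steps across grid-x and threadIdx.y, while threadIdx.x strides over the feature width. A host-side emulation of that mapping follows; this is my sketch with made-up dimensions, not code from the commit, and it traces only the kBatchLengthWidth branch of the kernel's index math:

// Host-side emulation of the SequencePaddingKernel thread-to-element mapping.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t seq_num = 2, max_seq_len = 3, seq_width = 4;
  const size_t block_dim_x = 2, block_dim_y = 2;  // threads per block
  const size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
  const size_t grid_dim_y = seq_num;              // one grid-y slot per sequence

  for (size_t by = 0; by < grid_dim_y; ++by)      // blockIdx.y -> seq_idx
    for (size_t bx = 0; bx < grid_dim_x; ++bx)    // blockIdx.x
      for (size_t ty = 0; ty < block_dim_y; ++ty)     // threadIdx.y
        for (size_t tx = 0; tx < block_dim_x; ++tx) { // threadIdx.x
          size_t seq_idx = by;
          size_t seq_step_idx = bx * block_dim_y + ty;
          if (seq_step_idx >= max_seq_len) continue;  // past the padded length
          // Each thread strides over the width, as in the kernel's for-loop.
          for (size_t i = tx; i < seq_width; i += block_dim_x) {
            // kBatchLengthWidth offset, as in the kernel's else branch.
            size_t pad_idx =
                (seq_idx * max_seq_len + seq_step_idx) * seq_width + i;
            std::printf("seq %zu step %zu col %zu -> pad[%zu]\n",
                        seq_idx, seq_step_idx, i, pad_idx);
          }
        }
  return 0;
}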
