
Commit 1945b72

Fix CPPLint issues with math/sequence_padding (#10317)
* Fix cpplint issues in sequence_padding
* Fix typo in cu file
* Fix dependencies of sequence_padding
* Add include
1 parent 9bcd9f6 commit 1945b72
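
Most of these edits address cpplint's `runtime/references` check: Google C++ style wants mutable (output) parameters passed by pointer rather than by non-const reference, so the write is visible at the call site. A minimal sketch of the convention, with illustrative names (`Config`, `Buffer`, `Fill`) that are not taken from this commit:

    // Flagged by cpplint (runtime/references): the caller writes
    //   Fill(cfg, buf);
    // and nothing at the call site says buf is mutated.
    void Fill(const Config& cfg, Buffer& out);

    // Preferred: the caller must write Fill(cfg, &buf), which makes
    // the output parameter explicit.
    void Fill(const Config& cfg, Buffer* out);

This is why every mutated `framework::Tensor& padding` / `framework::LoDTensor& seq` parameter below becomes a pointer, with `.` access rewritten to `->` and call sites taking the address.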

File tree

5 files changed: 28 additions & 26 deletions


paddle/fluid/operators/math/sequence_padding.cc

Lines changed: 8 additions & 8 deletions
@@ -22,7 +22,7 @@ template <typename T>
 class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::LoDTensor& seq, framework::Tensor& padding,
+                  const framework::LoDTensor& seq, framework::Tensor* padding,
                   bool norm_by_times) {
     auto lod = seq.lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -37,7 +37,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "The first dimension of LoDTensor seq should be "
                       "equal to the sum of all sequences's length.");
 
-    auto padding_dims = padding.dims();
+    auto padding_dims = padding->dims();
     PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                       "The input padding should be a 3-D Tensor of shape "
                       "[max_sequence_length, num_sequences, sequence_width].");
@@ -58,7 +58,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "width of sequence in LoDTensor seq.");
 
     const T* seq_data = seq.data<T>();
-    T* padding_data = padding.data<T>();
+    T* padding_data = padding->data<T>();
     for (int64_t i = 0; i < max_sequence_length; ++i) {
       for (int64_t j = 0; j < num_sequences; ++j) {
         int64_t start_pos = abs_offset_lod[level][j];
@@ -84,16 +84,16 @@ template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  framework::LoDTensor& seq, const framework::Tensor& padding,
+                  framework::LoDTensor* seq, const framework::Tensor& padding,
                   bool norm_by_times) {
-    auto lod = seq.lod();
+    auto lod = seq->lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
                       "The LoD of LoDTensor seq should not be null.");
 
     const size_t level = 0;
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
 
-    auto seq_dims = seq.dims();
+    auto seq_dims = seq->dims();
     PADDLE_ENFORCE_EQ(seq_dims[0],
                       static_cast<int64_t>(abs_offset_lod[level].back()),
                       "The first dimension of LoDTensor seq should be "
@@ -114,13 +114,13 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                       "The second dimension of Tensor padding should be "
                       "the number of sequences in LoDTensor seq.");
 
-    const int64_t sequence_width = seq.numel() / seq_dims[0];
+    const int64_t sequence_width = seq->numel() / seq_dims[0];
     PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                       "The third dimension of Tensor padding should be the "
                       "width of sequence in LoDTensor seq.");
 
     const T* padding_data = padding.data<T>();
-    T* seq_data = seq.data<T>();
+    T* seq_data = seq->data<T>();
     for (int64_t i = 0; i < num_sequences; ++i) {
       int64_t start_pos = abs_offset_lod[level][i];
       int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;

paddle/fluid/operators/math/sequence_padding.cu

Lines changed: 13 additions & 12 deletions
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
 #include "paddle/fluid/operators/math/sequence_padding.h"
 
 namespace paddle {
@@ -61,7 +62,7 @@ template <typename T>
 class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::LoDTensor& seq, framework::Tensor& padding,
+                  const framework::LoDTensor& seq, framework::Tensor* padding,
                   bool norm_by_times) {
     auto lod = seq.lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -76,7 +77,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "The first dimension of LoDTensor seq should be "
                       "equal to the sum of all sequences's length.");
 
-    auto padding_dims = padding.dims();
+    auto padding_dims = padding->dims();
     PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                       "The input padding should be a 3-D Tensor of shape "
                       "[max_sequence_length, num_sequences, sequence_width].");
@@ -97,8 +98,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "width of sequence in LoDTensor seq.");
 
     if (!norm_by_times && num_sequences == 1UL) {
-      TensorCopy(seq, context.GetPlace(), context, &padding);
-      padding.Resize(padding_dims);
+      TensorCopy(seq, context.GetPlace(), context, padding);
+      padding->Resize(padding_dims);
       return;
     }
 
@@ -117,7 +118,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* seq_data = seq.data<T>();
-    T* padding_data = padding.data<T>();
+    T* padding_data = padding->data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
           padding_data, const_cast<T*>(seq_data),
@@ -136,16 +137,16 @@ template <typename T>
 class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  framework::LoDTensor& seq, const framework::Tensor& padding,
+                  framework::LoDTensor* seq, const framework::Tensor& padding,
                   bool norm_by_times) {
-    auto lod = seq.lod();
+    auto lod = seq->lod();
     PADDLE_ENFORCE_GT(lod.size(), 0UL,
                       "The lod of LoDTensor seq should not be null.");
 
     const size_t level = 0;
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
 
-    auto seq_dims = seq.dims();
+    auto seq_dims = seq->dims();
     PADDLE_ENFORCE_EQ(seq_dims[0],
                       static_cast<int64_t>(abs_offset_lod[level].back()),
                       "The first dimension of LoDTensor seq should be "
@@ -166,14 +167,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                       "The second dimension of Tensor padding should be "
                       "the number of sequences in LoDTensor seq.");
 
-    const int64_t sequence_width = seq.numel() / seq_dims[0];
+    const int64_t sequence_width = seq->numel() / seq_dims[0];
     PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                       "The third dimension of Tensor padding should be the "
                       "width of sequence in LoDTensor seq.");
 
     if (!norm_by_times && num_sequences == 1UL) {
-      TensorCopy(padding, context.GetPlace(), context, &seq);
-      seq.Resize(seq_dims);
+      TensorCopy(padding, context.GetPlace(), context, seq);
+      seq->Resize(seq_dims);
       return;
     }
 
@@ -192,7 +193,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(grid_dim_x, grid_dim_y);
 
     const T* padding_data = padding.data<T>();
-    T* seq_data = seq.data<T>();
+    T* seq_data = seq->data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
           const_cast<T*>(padding_data), seq_data,

paddle/fluid/operators/math/sequence_padding.h

Lines changed: 3 additions & 2 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <algorithm>
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 
@@ -64,13 +65,13 @@ template <typename DeviceContext, typename T>
 class PaddingLoDTensorFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
-                  framework::Tensor& padding, bool norm_by_times);
+                  framework::Tensor* padding, bool norm_by_times);
 };
 
 template <typename DeviceContext, typename T>
 class UnpaddingLoDTensorFunctor {
  public:
-  void operator()(const DeviceContext& context, framework::LoDTensor& seq,
+  void operator()(const DeviceContext& context, framework::LoDTensor* seq,
                   const framework::Tensor& padding, bool norm_by_times);
 };
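
With the declarations above, a caller passes the address of whichever tensor is written. A hedged sketch of a caller under the new signatures (`ctx`, `seq`, and `padding` are illustrative; allocation and LoD setup are elided) — the test diff below exercises the real call sites:

    namespace math = paddle::operators::math;

    // Pad: seq is read-only (const ref), padding is the pointer output.
    math::PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>()(
        ctx, seq, &padding, /*norm_by_times=*/false);

    // Unpad: the roles flip, so seq is now the pointer output.
    math::UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>()(
        ctx, &seq, padding, /*norm_by_times=*/false);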

paddle/fluid/operators/math/sequence_padding_test.cc

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
                            static_cast<int64_t>(sequence_width)});
   padding.mutable_data<T>(padding_dims, *place);
   paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq, padding, false);
+      *context, seq, &padding, false);
 
   seq_back.set_lod(lod);
   seq_back.mutable_data<T>(seq_dims, *place);
   paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-      *context, seq_back, padding, false);
+      *context, &seq_back, padding, false);
 
   if (paddle::platform::is_cpu_place(*place)) {
     cpu_seq_back = seq_back;

paddle/fluid/operators/warpctc_op.h

Lines changed: 2 additions & 2 deletions
@@ -162,7 +162,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
                                          static_cast<int64_t>(sequence_width)});
     warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *logits, warpctc_logits,
+        ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
         false);
     const T* warpctc_logits_data = warpctc_logits.data<T>();
 
@@ -217,7 +217,7 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        ctx.template device_context<DeviceContext>(), *logits_grad,
+        ctx.template device_context<DeviceContext>(), logits_grad,
         *warpctc_grad, norm_by_times);
 
     const T* loss_grad_data = loss_grad->data<T>();
