Skip to content

Commit 5e7aa8c

Browse files
committed
code clean
1 parent 855c9e3 commit 5e7aa8c

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

paddle/fluid/framework/tensor_impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) {
5959
}
6060

6161
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
62+
int rank = src.dims().size();
63+
PADDLE_ENFORCE_GE(
64+
rank, 2,
65+
"'ReshapeToMatrix()' is only used for flatten high rank "
66+
"tensors to matrixs. Can not be used in reshaping vectors.");
67+
if (rank == 2) {
68+
return src;
69+
}
6270
Tensor res;
6371
res.ShareDataWith(src);
6472
res.Resize(flatten_to_2d(src.dims(), num_col_dims));

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
4545
"Input(Label) should be 1.");
4646
}
4747

48-
auto out_dim_vec =
49-
framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
50-
out_dim_vec.push_back(1);
51-
52-
ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
48+
auto y_dims = x_dims;
49+
y_dims[rank - 1] = 1;
50+
ctx->SetOutputDim("Y", y_dims);
5351
ctx->ShareLoD("X", /*->*/ "Y");
5452
}
5553

paddle/fluid/operators/cross_entropy_op.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
3434
y->mutable_data<T>(ctx.GetPlace());
3535

3636
int rank = x->dims().size();
37-
Tensor x_2d = rank > 2 ? framework::ReshapeToMatrix(*x, rank - 1) : *x;
38-
Tensor labels_2d =
39-
rank > 2 ? framework::ReshapeToMatrix(*labels, rank - 1) : *labels;
40-
Tensor y_2d = rank > 2 ? framework::ReshapeToMatrix(*y, rank - 1) : *y;
37+
Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
38+
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
39+
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
4140

4241
math::CrossEntropyFunctor<DeviceContext, T>()(
4342
ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,

paddle/fluid/operators/softmax_op.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
3232
Out->mutable_data<T>(context.GetPlace());
3333

3434
int rank = X->dims().size();
35-
Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
36-
Tensor Out_2d =
37-
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
35+
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
36+
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
3837

3938
math::SoftmaxFunctor<DeviceContext, T>()(
4039
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
@@ -53,11 +52,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
5352
dX->mutable_data<T>(context.GetPlace());
5453

5554
int rank = Out->dims().size();
56-
Tensor Out_2d =
57-
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
58-
Tensor dOut_2d =
59-
rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
60-
Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
55+
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
56+
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
57+
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
6158

6259
math::SoftmaxGradFunctor<DeviceContext, T>()(
6360
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,

0 commit comments

Comments
 (0)