Commit 7c55e08 ("stash")
Parent: b656d97

1 file changed: paddle/fluid/operators/cross_entropy_op.cc (+54 additions, -37 deletions)
@@ -28,23 +28,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
 
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) shall have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) shall have the same shape "
+                      "except the last dimension.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "If Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
-                        "If Attr(softLabel) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    auto out_dim_vec =
+        framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
+    out_dim_vec.push_back(1);
+
+    ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
     ctx->ShareLoD("X", /*->*/ "Y");
   }
 
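
Taken together, this hunk replaces the hard-coded rank-2 checks with rank-agnostic ones: only the last dimension is treated as the class dimension, and Y keeps every other dimension of X. Below is a minimal standalone sketch of the new output-shape rule, using std::vector in place of Paddle's DDim; InferYDims is a hypothetical stand-in for the slice_ddim / vectorize / make_ddim sequence above, and the rank >= 2 assumption is for illustration only.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the new shape rule: Y keeps every dimension of X
// except the last, which becomes 1 (one loss value per "row" of X).
std::vector<int64_t> InferYDims(const std::vector<int64_t>& x_dims) {
  assert(x_dims.size() >= 2);  // illustration-only assumption
  // Mirrors framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1)).
  std::vector<int64_t> out(x_dims.begin(), x_dims.end() - 1);
  out.push_back(1);  // mirrors out_dim_vec.push_back(1)
  return out;
}

int main() {
  // X: [batch = 4, seq_len = 16, num_classes = 10]  ->  Y: [4, 16, 1]
  assert((InferYDims({4, 16, 10}) == std::vector<int64_t>{4, 16, 1}));
  return 0;
}
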

@@ -74,24 +79,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "When Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
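
The gradient op gets the same generalization: each former rank-2 check becomes a "same shape except the last dimension" comparison between sliced prefixes. A self-contained sketch of what that comparison does, where SamePrefixShape is a hypothetical helper mirroring the slice_ddim equality checks above:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: two shapes must have the same rank and match on
// every dimension except the last, as enforced by the slice_ddim checks.
bool SamePrefixShape(const std::vector<int64_t>& a,
                     const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return false;  // ranks must match first
  for (std::size_t i = 0; i + 1 < a.size(); ++i) {
    if (a[i] != b[i]) return false;
  }
  return true;
}

int main() {
  // X: [8, 32, 100] vs Y@Grad: [8, 32, 1] -> prefixes [8, 32] match.
  assert(SamePrefixShape({8, 32, 100}, {8, 32, 1}));
  // A mismatched prefix ([8, 16] vs [8, 32]) is rejected.
  assert(!SamePrefixShape({8, 16, 100}, {8, 32, 1}));
  return 0;
}
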
@@ -113,25 +122,33 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
-             " where N is the batch size and D is the number of classes. "
-             "This input is a probability computed by the previous operator, "
-             "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor), the ground truth which is a 2-D tensor. When "
-             "soft_label is set to false, Label is a Tensor<int64> with shape "
-             "[N x 1]. When soft_label is set to true, Label is a "
-             "Tensor<float/double> with shape [N x D].");
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except for the last dimension. When soft_label is "
+        "set to false, the last dimension size is 1; when soft_label is set "
+        "to true, the last dimension size is equal to the number of classes.");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.");
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpretate the given labels as soft labels.")
         .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.
 
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the last dimension
+of the original tensor, and the first dimension (column length) is the product
+of all the other original dimensions. The cross-entropy computation then takes
+place on each row of the flattened matrices.
+
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
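
The new DOC paragraph amounts to simple shape arithmetic: an N-D tensor [d0, ..., d_{n-1}] is viewed as a 2-D matrix of shape [d0 * ... * d_{n-2}, d_{n-1}]. A short illustrative sketch of that flattening (FlattenTo2D is hypothetical, not part of the operator):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative helper: the 2-D view described in the DOC comment. Row length
// is the original last dimension; column length is the product of the rest.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims) {
  int64_t rows = 1;
  for (std::size_t i = 0; i + 1 < dims.size(); ++i) rows *= dims[i];
  return {rows, dims.back()};
}

int main() {
  // [4, 16, 10] flattens to a 64 x 10 matrix; the loss is one value per row.
  auto view = FlattenTo2D({4, 16, 10});
  assert(view.first == 64 && view.second == 10);
  return 0;
}
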
