@@ -28,23 +28,26 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) shall have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) shall have the same shape "
+                      "except the last dimension.");
 
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "If Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
-                        "If Attr(softLabel) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(softLabel) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
 
-    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
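
For intuition, here is a minimal standalone sketch of the shape contract the new InferShape enforces. It is plain C++, not Paddle code: std::vector<int64_t> stands in for framework::DDim, a simplified Slice helper stands in for framework::slice_ddim, and the example shapes are invented.

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for framework::slice_ddim: keeps dims in [begin, end).
std::vector<int64_t> Slice(const std::vector<int64_t>& dims, int begin, int end) {
  return std::vector<int64_t>(dims.begin() + begin, dims.begin() + end);
}

int main() {
  // Example: X = [N, M, D], Label = [N, M, 1] (hard labels, soft_label == false).
  std::vector<int64_t> x_dims = {8, 16, 10};
  std::vector<int64_t> label_dims = {8, 16, 1};

  int rank = static_cast<int>(x_dims.size());
  // Same rank, and every dimension except the last must match.
  assert(rank == static_cast<int>(label_dims.size()));
  assert(Slice(x_dims, 0, rank - 1) == Slice(label_dims, 0, rank - 1));
  // Hard-label case: the last dimension of Label must be 1.
  assert(label_dims[rank - 1] == 1);

  // Y keeps X's shape with the last dimension set to 1: here [8, 16, 1].
  std::vector<int64_t> y_dims = x_dims;
  y_dims[rank - 1] = 1;
  assert(y_dims == std::vector<int64_t>({8, 16, 1}));
  return 0;
}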
@@ -74,24 +77,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
 
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "When Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
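
Not shown in the diff, but useful context for these checks: in the hard-label case the backward pass is the derivative of -log(x[label]), so dX equals -dY / X at the label index of each flattened row and zero elsewhere, which is why dX takes exactly X's shape. A minimal standalone C++ sketch with invented values:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // One flattened row of X: three class probabilities, a hard label, and
  // the upstream gradient dy flowing in from Y@Grad (last dimension 1).
  std::vector<double> x_row = {0.2, 0.7, 0.1};
  int label = 1;
  double dy = 1.0;

  // d(-log(x[label]))/dx is -1/x at the label index and 0 elsewhere.
  std::vector<double> dx_row(x_row.size(), 0.0);
  dx_row[label] = -dy / x_row[label];

  for (std::size_t j = 0; j < dx_row.size(); ++j) {
    std::cout << "dx[" << j << "] = " << dx_row[j] << "\n";  // 0, -1.42857, 0
  }
  return 0;
}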
@@ -113,25 +120,33 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
-             " where N is the batch size and D is the number of classes. "
-             "This input is a probability computed by the previous operator, "
-             "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor), the ground truth which is a 2-D tensor. When "
-             "soft_label is set to false, Label is a Tensor<int64> with shape "
-             "[N x 1]. When soft_label is set to true, Label is a "
-             "Tensor<float/double> with shape [N x D].");
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except the last dimension. When soft_label is set "
+        "to false, the last dimension size is 1; when soft_label is set to "
+        "true, the last dimension size is equal to the number of classes.");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.");
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpret the given labels as soft labels.")
         .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.
 
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the original last
+dimension, and the first dimension (column length) is the product of all the
+other original dimensions. The cross entropy computation then takes place on
+each row of the flattened matrices.
+
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
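
To make the flattening described in the new DOC comment concrete, here is a small self-contained C++ sketch, independent of Paddle and with invented shapes and values. A [2, 2, 3] input is viewed as a [4 x 3] matrix, the hard-label loss is computed per row, and the four results correspond to a Y of shape [2, 2, 1]:

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // X has shape [2, 2, 3]; the last dimension holds 3 class probabilities.
  // Logically flattened, it is a [4 x 3] matrix stored row by row.
  std::vector<double> x = {0.70, 0.20, 0.10,  0.10, 0.80, 0.10,
                           0.30, 0.30, 0.40,  0.25, 0.50, 0.25};
  // Label has shape [2, 2, 1]; flattened, one class index per row.
  std::vector<int64_t> label = {0, 1, 2, 1};
  const int num_classes = 3;

  // One-hot cross entropy per row: Y[i] = -log(X[i, Label[i]]).
  for (std::size_t i = 0; i < label.size(); ++i) {
    double p = x[i * num_classes + static_cast<std::size_t>(label[i])];
    std::cout << "Y[" << i << "] = " << -std::log(p) << "\n";
  }
  return 0;
}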