@@ -28,23 +28,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
-                      "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
+                      "Input(X) and Input(Label) shall have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "Input(X) and Input(Label) shall have the same shape "
+                      "except the last dimension.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "If Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "If Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
-                        "If Attr(softLabel) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                        "If Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }

-    ctx->SetOutputDim("Y", {x_dims[0], 1});
+    auto out_dim_vec =
+        framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
+    out_dim_vec.push_back(1);
+
+    ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
     ctx->ShareLoD("X", /*->*/ "Y");
   }
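In effect, the rewritten InferShape generalizes the old [N x D] -> [N x 1] rule to inputs of arbitrary rank: every leading dimension of X is preserved, and only the class dimension collapses to 1. A minimal standalone sketch of that shape rule in plain C++ (the function name and STL types are illustrative, not the framework's DDim-based API):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative helper: derive the shape that InferShape assigns to Y
// from the shape of X.
std::vector<int64_t> InferCrossEntropyOutDims(
    const std::vector<int64_t>& x_dims) {
  assert(x_dims.size() >= 2);  // at least one batch-like dim plus classes
  // Keep all leading dimensions of X and collapse the class dimension to 1,
  // e.g. X: [N, M, D] -> Y: [N, M, 1]; the old code only handled [N, D].
  std::vector<int64_t> out(x_dims.begin(), x_dims.end() - 1);
  out.push_back(1);
  return out;
}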
@@ -74,24 +79,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
-    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
-                      "The 1st dimension of Input(X) and Input(Label) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
-                      "be equal.");
-    PADDLE_ENFORCE_EQ(dy_dims[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
+                      "Input(Y@Grad) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
+                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(label_dims, 0, rank - 1),
+                      "The Input(X) and Input(Label) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
+                      framework::slice_ddim(dy_dims, 0, rank - 1),
+                      "The Input(X) and Input(Y@Grad) should have the same "
+                      "shape except the last dimension.");
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
     if (ctx->Attrs().Get<bool>("soft_label")) {
-      PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-                        "When Attr(soft_label) == true, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
+                        "When Attr(soft_label) == true, the last dimension of "
                         "Input(X) and Input(Label) should be equal.");
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[1], 1,
-                        "When Attr(soft_label) == false, the 2nd dimension of "
+      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
+                        "When Attr(soft_label) == false, the last dimension of "
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
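All of the slice_ddim comparisons above implement a single predicate: two inputs must agree on every dimension except the last. A hedged sketch of that check in plain C++ (SamePrefixShape is an illustrative name, not a Paddle helper):

#include <cstdint>
#include <vector>

// Illustrative: true if a and b have the same rank and agree on every
// dimension except the last one.
bool SamePrefixShape(const std::vector<int64_t>& a,
                     const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return false;  // ranks must match first
  for (size_t i = 0; i + 1 < a.size(); ++i) {
    if (a[i] != b[i]) return false;  // compare all but the last dimension
  }
  return true;
}

The gradient op applies this predicate to (X, Label) and to (X, Y@Grad), and additionally requires the last dimension of Y@Grad to be exactly 1.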
@@ -113,25 +122,33 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
-             " where N is the batch size and D is the number of classes. "
-             "This input is a probability computed by the previous operator, "
-             "which is almost always the result of a softmax operator.");
-    AddInput("Label",
-             "(Tensor), the ground truth which is a 2-D tensor. When "
-             "soft_label is set to false, Label is a Tensor<int64> with shape "
-             "[N x 1]. When soft_label is set to true, Label is a "
-             "Tensor<float/double> with shape [N x D].");
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except the last dimension. When soft_label is set "
+        "to false, the last dimension size is 1; when soft_label is set to "
+        "true, the last dimension size is equal to the number of classes.");
     AddOutput("Y",
-              "(Tensor, default Tensor<float>), a 2-D tensor with shape "
-              "[N x 1]. The cross entropy loss.");
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
     AddAttr<bool>("soft_label",
                   "(bool, default false), a flag indicating whether to "
                   "interpret the given labels as soft labels.")
         .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.

+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the last dimension
+of the original tensors, and the first dimension (column length) is the
+product of all the other original dimensions. The cross entropy computation
+then takes place on each row of the flattened matrices.
+
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.

 1) One-hot cross-entropy:
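To make the flattening described in the DOC comment concrete: a tensor of shape [d0, d1, ..., d_{n-1}] is viewed as a matrix with d0 * ... * d_{n-2} rows of length d_{n-1}, and the loss is computed row by row. A small sketch under that reading (the function name and types are illustrative, not part of the operator):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative: the logical 2-D view used when computing the loss.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims) {
  assert(dims.size() >= 2);
  // The product of all leading dimensions becomes the row count,
  // e.g. X: [N, M, D] flattens to [N * M, D].
  int64_t rows = std::accumulate(dims.begin(), dims.end() - 1, int64_t{1},
                                 std::multiplies<int64_t>());
  return {rows, dims.back()};
}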