@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include <memory>
 #include <string>
 #include <unordered_map>
 
 namespace paddle {
 namespace operators {
 
-class CrossEntropyOp : public framework::OperatorWithKernel {
+class CrossEntropyOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
@@ -44,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
           "Input(X) and Input(Label) shall have the same shape "
           "except the last dimension.");
     }
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                           "If Attr(soft_label) == true, the last dimension of "
@@ -70,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                    ctx.device_context());
   }
+
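+  // Shape-checking hook: subclasses that support only hard labels (e.g.
+  // CrossEntropyOp2 below) override this to return false unconditionally.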
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
 };
 
-class CrossEntropyGradientOp : public framework::OperatorWithKernel {
+class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+  void InferShape(framework::InferShapeContext* ctx) const {
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                    "Input(Y@GRAD) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@GRAD) should be not null.");
 
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = GetXDim(ctx);
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
     int rank = x_dims.size();
@@ -109,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
           "The Input(X) and Input(Y@Grad) should have the same "
           "shape except the last dimension.");
     }
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(
             x_dims[rank - 1], label_dims[rank - 1],
@@ -124,16 +128,39 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
           "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
   }
 
  protected:
   // Explicitly set that the data type of the computation kernel of
   // cross_entropy's gradient is determined by its input "Y@GRAD".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
+        ctx.device_context());
+  }
+
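+  // Hooks overridden by CrossEntropyGradientOp2 so the same shape inference
+  // can run when X itself is not an input of the backward op: X's dims and
+  // LoD are then recovered from the saved XShape variable instead.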
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    return ctx->GetInputDim("X");
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "X"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
+};
+
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
   }
 };
 
@@ -201,26 +228,147 @@ or not. But the output only shares the LoD information with input X.
   }
 };
 
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
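+// The original backward op: unlike the shared base, it still takes X as an
+// input, so it only adds the X presence check on top of the base InferShape.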
+class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+};
+
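+// Forward op of cross_entropy2 (hard label only). Besides the loss Y, it
+// emits XShape and MatchX so that the backward op can run without X.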
+class CrossEntropyOp2 : public CrossEntropyOpBase {
+ public:
+  using CrossEntropyOpBase::CrossEntropyOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    CrossEntropyOpBase::InferShape(ctx);
+
+    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+                   "Output(XShape) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
+                   "Output(MatchX) should be not null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims_vec = framework::vectorize(x_dims);
+    x_dims_vec.push_back(0);
+    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
+    x_dims[x_dims.size() - 1] = 1;
+    ctx->SetOutputDim("MatchX", x_dims);
+    ctx->ShareLoD("X", /*->*/ "XShape");
+  }
+
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
+  }
+};
+
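+// Backward op of cross_entropy2: it consumes MatchX and XShape (saved by the
+// forward op) instead of X, overriding the base-class hooks accordingly.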
+class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+
+ protected:
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    auto x_shape = ctx->GetInputDim("XShape");
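+    // Drop the trailing dummy 0 to recover the real dims of X.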
+    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "XShape"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return false;
   }
 };
+
+class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape as 'X' except the last dimension. A one-hot Tensor.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is the "
+              "same as 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
+    AddOutput("XShape", "Temporary variable to save the shape and LoD of X.");
+    AddOutput("MatchX",
+              "X value that matches label, used for gradient computation.");
+    AddAttr<int>("ignore_index",
+                 "(int, default -100), Specifies a target value that is "
+                 "ignored and does not contribute to the input gradient. "
+                 "Only valid if soft_label is set to False.")
+        .SetDefault(-100);
+    AddComment(R"DOC(
+Hard-label CrossEntropy Operator.
+
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the last dimension
+of the original input, and the first dimension (column length) is the product
+of all the other original dimensions. The cross entropy is then computed
+independently on each row of the flattened matrices.
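+For example, an input X of shape [N, M, D] is viewed as an [N * M, D] matrix,
+and one loss value is produced for each of its N * M rows.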
321
+
322
+ Only support hard label.
323
+
324
+ Both the input X and Label can carry the LoD (Level of Details) information,
325
+ or not. But the output only shares the LoD information with input X.
326
+
327
+ )DOC" );
328
+ }
329
+ };
330
+
331
+ class CrossEntropyGradOpDescMaker2 : public framework ::SingleGradOpDescMaker {
332
+ public:
333
+ using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
334
+
335
+ protected:
336
+ std::unique_ptr<framework::OpDesc> Apply () const override {
337
+ std::unique_ptr<framework::OpDesc> op (new framework::OpDesc ());
338
+ op->SetType (" cross_entropy_grad2" );
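+    // Wire the backward op to Label, MatchX, XShape and Y@GRAD only; X
+    // itself is deliberately not an input of cross_entropy_grad2.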
+    op->SetInput("Label", Input("Label"));
+    op->SetInput("MatchX", Output("MatchX"));
+    op->SetInput("XShape", Output("XShape"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 
-REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
-                  ops::CrossEntropyOpInferVarType,
+REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
+                  ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
                        ops::CrossEntropyOpKernel<CPUCtx, double>);
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
+
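+// cross_entropy2 is the hard-label-only variant; unlike cross_entropy, it
+// uses the custom grad desc maker above, so its backward op depends on
+// XShape/MatchX rather than on X.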
+REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
+                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
+                  ops::CrossEntropyGradOpDescMaker2);
+REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
+REGISTER_OP_CPU_KERNEL(cross_entropy2,
+                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);