@@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/cross_entropy_op.h"
-#include <memory>
 #include <string>
 #include <unordered_map>

 namespace paddle {
 namespace operators {

-class CrossEntropyOpBase : public framework::OperatorWithKernel {
+class CrossEntropyOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");

    auto x_dims = ctx->GetInputDim("X");
@@ -46,8 +44,7 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
          "Input(X) and Input(Label) shall have the same shape "
          "except the last dimension.");
    }
-
-    if (IsSoftLabel(ctx)) {
+    if (ctx->Attrs().Get<bool>("soft_label")) {
      if (check) {
        PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                          "If Attr(soft_label) == true, the last dimension of "
@@ -73,24 +70,21 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context());
  }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return ctx->Attrs().Get<bool>("soft_label");
-  }
 };

-class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
+class CrossEntropyGradientOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const {
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should be not null.");

-    auto x_dims = GetXDim(ctx);
+    auto x_dims = ctx->GetInputDim("X");
    auto label_dims = ctx->GetInputDim("Label");
    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
    int rank = x_dims.size();
@@ -115,7 +109,9 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
                        "The Input(X) and Input(Y@Grad) should have the same "
                        "shape except the last dimension.");
    }
-    if (IsSoftLabel(ctx)) {
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    if (ctx->Attrs().Get<bool>("soft_label")) {
      if (check) {
        PADDLE_ENFORCE_EQ(
            x_dims[rank - 1], label_dims[rank - 1],
@@ -128,39 +124,16 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
                          "Input(Label) should be 1.");
    }
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
+    ctx->ShareLoD("X", framework::GradVarName("X"));
  }

 protected:
  // Explicitly set that the data type of computation kernel of cross_entropy
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
-        ctx.device_context());
-  }
-
-  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
-    return ctx->GetInputDim("X");
-  }
-
-  virtual const char* VarNameWithXLoD() const { return "X"; }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return ctx->Attrs().Get<bool>("soft_label");
-  }
-};
-
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context());
  }
 };

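As a reference for the backward pass these checks guard: the soft-label loss for one row is L = -sum_j label[j] * log(x[j]), so the gradient the kernel must produce is dX[j] = -label[j] / x[j] * dY, with dY a per-row scalar because Y's last dimension is 1. A minimal standalone sketch, again in plain C++ and not the actual Paddle kernel:

#include <cstddef>
#include <vector>

// Gradient of the soft-label loss w.r.t. x for one row:
// dx[j] = -label[j] / x[j] * dy, where dy is the scalar upstream
// gradient (Y's last dimension is 1, as enforced above).
std::vector<double> SoftLabelGrad(const std::vector<double>& x,
                                  const std::vector<double>& label,
                                  double dy) {
  std::vector<double> dx(x.size());
  for (std::size_t j = 0; j < x.size(); ++j) {
    dx[j] = -label[j] / x[j] * dy;
  }
  return dx;
}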
@@ -228,137 +201,26 @@ or not. But the output only shares the LoD information with input X.
  }
 };

-class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
- public:
-  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    CrossEntropyGradientOpBase::InferShape(ctx);
-  }
-};
-
-class CrossEntropyOp2 : public CrossEntropyOpBase {
- public:
-  using CrossEntropyOpBase::CrossEntropyOpBase;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    CrossEntropyOpBase::InferShape(ctx);
-
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
-                   "Output(XShape) should be not null.");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_dims_vec = framework::vectorize(x_dims);
-    x_dims_vec.push_back(0);
-    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
-    ctx->ShareLoD("X", /*->*/ "XShape");
-  }
-
- protected:
-  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
-    return false;
-  }
-};
-
-class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
- public:
-  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
-
- protected:
-  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
-    auto x_shape = ctx->GetInputDim("XShape");
-    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
-  }
-
-  virtual const char* VarNameWithXLoD() const { return "XShape"; }
-
-  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
-    return false;
-  }
-};
-
-class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor, default Tensor<float>), a tensor whose last dimension "
-             "size is equal to the number of classes. This input is a "
-             "probability computed by the previous operator, which is almost "
-             "always the result of a softmax operator.");
-    AddInput(
-        "Label",
-        "(Tensor), the tensor which represents the ground truth. It has the "
-        "same shape with 'X' except the last dimension. One hot Tensor.");
-    AddOutput("Y",
-              "(Tensor, default Tensor<float>), a tensor whose shape is same "
-              "with 'X' except that the last dimension size is 1. It "
-              "represents the cross entropy loss.");
-    AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
-    AddAttr<int>("ignore_index",
-                 "(int, default -100), Specifies a target value that is"
-                 "ignored and does not contribute to the input gradient."
-                 "Only valid if soft_label is set to False")
-        .SetDefault(-100);
-    AddComment(R"DOC(
-Hard-label CrossEntropy Operator.
-
-The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
-The matrix's second dimension(row length) is as same as the original last
-dimension, and the first dimension(column length) is the product of all other
-original dimensions. Then the softmax computation will take palce on each raw
-of flattened matrixs.
-
-Only support hard label.
-
-Both the input X and Label can carry the LoD (Level of Details) information,
-or not. But the output only shares the LoD information with input X.
-
-)DOC");
-  }
-};
-
-class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
-    op->SetType("cross_entropy_grad2");
-    op->SetInput("Label", Input("Label"));
-    op->SetInput("Y", Output("Y"));
-    op->SetInput("XShape", Output("XShape"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
-    return op;
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
  }
 };
-
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;

-REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
-                  ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
+REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
+                  ops::CrossEntropyOpInferVarType,
                  paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
                        ops::CrossEntropyOpKernel<CPUCtx, double>);
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
-
-REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
-                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
-                  ops::CrossEntropyGradOpDescMaker2);
-REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
-REGISTER_OP_CPU_KERNEL(cross_entropy2,
-                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
-                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
-REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
-                       ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
-                       ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);