Commit 0ca6465

Merge pull request #16210 from sneaxiy/rnn_mem_opt
Fix cross_entropy2_op numeric error
2 parents: 79df026 + 3e03695

File tree: 12 files changed, +472 / -53 lines

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 167 additions & 19 deletions
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include <memory>
 #include <string>
 #include <unordered_map>
 
 namespace paddle {
 namespace operators {
 
-class CrossEntropyOp : public framework::OperatorWithKernel {
+class CrossEntropyOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+
     PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
@@ -44,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
                       "Input(X) and Input(Label) shall have the same shape "
                       "except the last dimension.");
     }
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                           "If Attr(soft_label) == true, the last dimension of "
@@ -70,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                    ctx.device_context());
   }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
 };
 
-class CrossEntropyGradientOp : public framework::OperatorWithKernel {
+class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+  void InferShape(framework::InferShapeContext* ctx) const {
     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                    "Input(Y@GRAD) should be not null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@GRAD) should be not null.");
 
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = GetXDim(ctx);
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
     int rank = x_dims.size();
@@ -109,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                       "The Input(X) and Input(Y@Grad) should have the same "
                       "shape except the last dimension.");
     }
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    if (ctx->Attrs().Get<bool>("soft_label")) {
+    if (IsSoftLabel(ctx)) {
       if (check) {
         PADDLE_ENFORCE_EQ(
             x_dims[rank - 1], label_dims[rank - 1],
@@ -124,16 +128,39 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
                         "Input(Label) should be 1.");
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
+                      "The last dimension of Input(Y@Grad) should be 1.");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
   }
 
  protected:
   // Explicitly set that the data type of computation kernel of cross_entropy
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
+        ctx.device_context());
+  }
+
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    return ctx->GetInputDim("X");
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "X"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return ctx->Attrs().Get<bool>("soft_label");
+  }
+};
+
+class CrossEntropyOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
   }
 };
 
@@ -201,26 +228,147 @@ or not. But the output only shares the LoD information with input X.
   }
 };
 
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
+class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+};
+
+class CrossEntropyOp2 : public CrossEntropyOpBase {
+ public:
+  using CrossEntropyOpBase::CrossEntropyOpBase;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    CrossEntropyOpBase::InferShape(ctx);
+
+    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+                   "Output(XShape) should be not null.");
+
+    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
+                   "Output(MatchX) should be not null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims_vec = framework::vectorize(x_dims);
+    x_dims_vec.push_back(0);
+    ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
+    x_dims[x_dims.size() - 1] = 1;
+    ctx->SetOutputDim("MatchX", x_dims);
+    ctx->ShareLoD("X", /*->*/ "XShape");
+  }
+
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
+  }
+};
+
+class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    CrossEntropyGradientOpBase::InferShape(ctx);
+  }
+
+ protected:
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    auto x_shape = ctx->GetInputDim("XShape");
+    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
+  }
+
+  virtual const char* VarNameWithXLoD() const { return "XShape"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return false;
   }
 };
+
+class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a tensor whose last dimension "
+             "size is equal to the number of classes. This input is a "
+             "probability computed by the previous operator, which is almost "
+             "always the result of a softmax operator.");
+    AddInput(
+        "Label",
+        "(Tensor), the tensor which represents the ground truth. It has the "
+        "same shape with 'X' except the last dimension. One hot Tensor.");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a tensor whose shape is same "
+              "with 'X' except that the last dimension size is 1. It "
+              "represents the cross entropy loss.");
+    AddOutput("XShape", "Temporary variable to save the shape and LoD of X.");
+    AddOutput("MatchX",
+              "X value that matches label, used for gradient computation.");
+    AddAttr<int>("ignore_index",
                 "(int, default -100), Specifies a target value that is "
                 "ignored and does not contribute to the input gradient. "
                 "Only valid if soft_label is set to False")
+        .SetDefault(-100);
+    AddComment(R"DOC(
+Hard-label CrossEntropy Operator.
+
+The input 'X' and 'Label' will first be logically flattened to 2-D matrices.
+The matrix's second dimension (row length) is the same as the original last
+dimension, and the first dimension (column length) is the product of all other
+original dimensions. The cross entropy computation then takes place on each
+row of the flattened matrices.
+
+Only hard labels are supported.
+
+Both the input X and Label can carry the LoD (Level of Details) information,
+or not. But the output only shares the LoD information with input X.
+
+)DOC");
+  }
+};
+
+class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("cross_entropy_grad2");
+    op->SetInput("Label", Input("Label"));
+    op->SetInput("MatchX", Output("MatchX"));
+    op->SetInput("XShape", Output("XShape"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 
-REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
-                  ops::CrossEntropyOpInferVarType,
+REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
+                  ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
                        ops::CrossEntropyOpKernel<CPUCtx, double>);
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
+
+REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
+                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
+                  ops::CrossEntropyGradOpDescMaker2);
+REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
+REGISTER_OP_CPU_KERNEL(cross_entropy2,
+                       ops::CrossEntropyOpKernel2<CPUCtx, float>,
+                       ops::CrossEntropyOpKernel2<CPUCtx, double>);
+REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
                        ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
                        ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
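
Taken together, the new classes define a hard-label cross entropy whose forward pass also emits MatchX, the value of X at each row's label index, so the backward pass can compute X@GRAD from MatchX alone (the grad op consumes Label, MatchX, XShape, and Y@GRAD, per CrossEntropyGradOpDescMaker2 above). Below is a minimal CPU sketch of that contract, assuming the standard formulas y = -log(x[label]) and dx[label] = -dy / match_x. The function names and the explicit [N, D] flattening are illustrative assumptions (and ignore_index is omitted for brevity); the real kernels (CrossEntropyOpKernel2 and friends) live in cross_entropy_op.h, which this commit page does not show.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Forward: treat X as an [n, d] matrix; for each row i pick the
    // probability at the label index, save it as match_x[i] (the analogue of
    // the MatchX output), and emit y[i] = -log(match_x[i]).
    void forward_xent2(const std::vector<float>& x,
                       const std::vector<int>& label, int n, int d,
                       std::vector<float>* y, std::vector<float>* match_x) {
      y->resize(n);
      match_x->resize(n);
      for (int i = 0; i < n; ++i) {
        float p = x[i * d + label[i]];
        (*match_x)[i] = p;  // saved for the gradient pass
        (*y)[i] = -std::log(p);
      }
    }

    // Backward: dX is zero except at each row's label index, where
    // d(-log(p))/dp = -1/p gives dx = -dy / match_x. Reading the saved
    // match_x instead of re-deriving it is what the MatchX output enables.
    void backward_xent2(const std::vector<float>& dy,
                        const std::vector<float>& match_x,
                        const std::vector<int>& label, int n, int d,
                        std::vector<float>* dx) {
      dx->assign(n * d, 0.0f);
      for (int i = 0; i < n; ++i) {
        (*dx)[i * d + label[i]] = -dy[i] / match_x[i];
      }
    }

    int main() {
      // One 2-class example: X = {{0.9, 0.1}}, label = {0}.
      std::vector<float> x = {0.9f, 0.1f};
      std::vector<int> label = {0};
      std::vector<float> y, match_x, dx;
      forward_xent2(x, label, /*n=*/1, /*d=*/2, &y, &match_x);
      backward_xent2(/*dy=*/{1.0f}, match_x, label, 1, 2, &dx);
      std::printf("y = %f, dx[0] = %f\n", y[0], dx[0]);
      return 0;
    }

With X = {0.9, 0.1} and label 0, this prints y ≈ 0.105361 and dx[0] ≈ -1.111111, matching d(-log(x))/dx = -1/x at x = 0.9.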

paddle/fluid/operators/cross_entropy_op.cu

Lines changed: 10 additions & 0 deletions
@@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL(
     cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
     ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(cross_entropy2,
+                        ops::CrossEntropyOpKernel2<CUDACtx, float>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
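
The cross_entropy2 kernels are also registered for plat::float16, where rounding error is most visible. This diff does not show the kernel change behind the commit title, so the following is only an illustrative guess at the class of problem that a saved MatchX avoids: reconstructing the matched probability from the loss via p = exp(-y), rather than reading the stored value, loses precision in the log/exp round trip. A minimal, standalone demonstration in single precision:

    #include <cmath>
    #include <cstdio>

    int main() {
      float p = 1e-6f;              // matched probability, as saved in MatchX
      float y = -std::log(p);       // forward loss y = -log(p)
      float p_back = std::exp(-y);  // reconstructed instead of saved
      std::printf("saved p   = %.9g\n", static_cast<double>(p));
      std::printf("exp(-y)   = %.9g\n", static_cast<double>(p_back));
      std::printf("rel error = %.3g\n",
                  static_cast<double>(std::fabs(p_back - p) / p));
      return 0;
    }

The relative error printed is small but nonzero in float, and grows in half precision; dividing dy by the round-tripped value would propagate it into the gradient.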
