Skip to content

Commit 6c6e638

Browse files
Author: chengduo
Add InferVarType for some op (#14201)
* add_infer_var_type (test=develop)
* InferVarTypeHelper -> VarTypeInferenceHelper (test=develop)
* PassInputTypeAndDTypeOnOutput (test=develop)
* follow comment (test=develop)
1 parent 0b38822 commit 6c6e638

File tree

11 files changed

+124
-29
lines changed

11 files changed

+124
-29
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
259259
if (row_size >= 0) {
260260
ss << "[row_size=" << row_size << "]";
261261
}
262+
std::string dtype = GetDtype(*scope, output.second[i]);
263+
ss << ":" << dtype;
262264
ss << "[" << GetDims(*scope, var_name, true) << "]";
263265
ss << "(" << GetLoD(*scope, var_name) << ")";
264266
}

paddle/fluid/framework/var_type_inference.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include <string>
17+
#include "paddle/fluid/framework/block_desc.h"
18+
#include "paddle/fluid/framework/op_desc.h"
1619
#include "paddle/fluid/framework/type_defs.h"
1720

1821
namespace paddle {
@@ -24,5 +27,27 @@ class VarTypeInference {
2427
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
2528
};
2629

30+
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
31+
public:
32+
void operator()(const framework::OpDesc& op_desc,
33+
framework::BlockDesc* block) const final {
34+
auto in_out_var_names = this->GetInputOutputWithSameType();
35+
36+
for (auto& i_o_n : in_out_var_names) {
37+
auto& x_name = op_desc.Input(i_o_n.first).at(0);
38+
auto& out_name = op_desc.Output(i_o_n.second).at(0);
39+
40+
auto& x = block->FindRecursiveOrCreateVar(x_name);
41+
auto& out = block->FindRecursiveOrCreateVar(out_name);
42+
out.SetType(x.GetType());
43+
out.SetDataType(x.GetDataType());
44+
}
45+
}
46+
47+
protected:
48+
virtual std::unordered_map<std::string, std::string>
49+
GetInputOutputWithSameType() const = 0;
50+
};
51+
2752
} // namespace framework
2853
} // namespace paddle

paddle/fluid/operators/activation_op.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,12 @@ class ActivationOp : public framework::OperatorWithKernel {
9191
}
9292
};
9393

94-
class ActivationOpInferVarType : public framework::VarTypeInference {
95-
public:
96-
void operator()(const framework::OpDesc& op_desc,
97-
framework::BlockDesc* block) const override {
98-
auto x_name = op_desc.Input("X")[0];
99-
auto out_name = op_desc.Output("Out")[0];
100-
auto& x = block->FindRecursiveOrCreateVar(x_name);
101-
auto& out = block->FindRecursiveOrCreateVar(out_name);
102-
out.SetType(x.GetType());
103-
out.SetDataType(x.GetDataType());
94+
class ActivationOpInferVarType
95+
: public framework::PassInDtypeAndVarTypeToOutput {
96+
protected:
97+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
98+
const override {
99+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
104100
}
105101
};
106102

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,15 @@ The required data format for this layer is one of the following:
170170
}
171171
};
172172

173+
class BatchNormOpInferVarType
174+
: public framework::PassInDtypeAndVarTypeToOutput {
175+
protected:
176+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
177+
const override {
178+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
179+
}
180+
};
181+
173182
template <typename T>
174183
class BatchNormKernel<platform::CPUDeviceContext, T>
175184
: public framework::OpKernel<T> {
@@ -525,7 +534,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
525534

526535
namespace ops = paddle::operators;
527536
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
528-
ops::BatchNormGradMaker);
537+
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
529538
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
530539

531540
REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/conv_op.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,15 @@ The input(X) size and output(Out) size may be different.
224224
)DOC");
225225
}
226226

227+
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
228+
protected:
229+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
230+
const override {
231+
return std::unordered_map<std::string, std::string>{
232+
{"Input", /*->*/ "Output"}};
233+
}
234+
};
235+
227236
void Conv3DOpMaker::Make() {
228237
AddInput(
229238
"Input",
@@ -365,14 +374,17 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
365374

366375
namespace ops = paddle::operators;
367376
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
377+
ops::ConvOpInferVarType,
368378
paddle::framework::DefaultGradOpDescMaker<true>);
369379
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
370380

371381
// depthwise convolution op
372382
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
373383
paddle::framework::DefaultGradOpDescMaker<true>);
374384
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
385+
375386
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
387+
ops::ConvOpInferVarType,
376388
paddle::framework::DefaultGradOpDescMaker<true>);
377389
REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
378390

paddle/fluid/operators/cross_entropy_op.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/cross_entropy_op.h"
16+
#include <string>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -179,13 +180,23 @@ or not. But the output only shares the LoD information with input X.
179180
)DOC");
180181
}
181182
};
183+
184+
class CrossEntropyOpInferVarType
185+
: public framework::PassInDtypeAndVarTypeToOutput {
186+
protected:
187+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
188+
const override {
189+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
190+
}
191+
};
182192
} // namespace operators
183193
} // namespace paddle
184194

185195
namespace ops = paddle::operators;
186196
using CPUCtx = paddle::platform::CPUDeviceContext;
187197

188198
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
199+
ops::CrossEntropyOpInferVarType,
189200
paddle::framework::DefaultGradOpDescMaker<true>);
190201
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
191202
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,

paddle/fluid/operators/elementwise_op.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
7575
}
7676
};
7777

78-
class ElementwiseOpInferVarType : public framework::VarTypeInference {
79-
public:
80-
void operator()(const framework::OpDesc &op_desc,
81-
framework::BlockDesc *block) const override {
82-
auto x_name = op_desc.Input("X")[0];
83-
auto out_name = op_desc.Output("Out")[0];
84-
auto &x = block->FindRecursiveOrCreateVar(x_name);
85-
auto &out = block->FindRecursiveOrCreateVar(out_name);
86-
out.SetType(x.GetType());
87-
out.SetDataType(x.GetDataType());
78+
class ElementwiseOpInferVarType
79+
: public framework::PassInDtypeAndVarTypeToOutput {
80+
protected:
81+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
82+
const override {
83+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
8884
}
8985
};
9086

paddle/fluid/operators/mean_op.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/mean_op.h"
16-
16+
#include <string>
1717
namespace paddle {
1818
namespace operators {
1919

@@ -42,6 +42,14 @@ Mean Operator calculates the mean of all elements in X.
4242
}
4343
};
4444

45+
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
46+
protected:
47+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
48+
const override {
49+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
50+
}
51+
};
52+
4553
class MeanGradOp : public framework::OperatorWithKernel {
4654
public:
4755
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,6 +58,14 @@ class MeanGradOp : public framework::OperatorWithKernel {
5058
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
5159
ctx->ShareLoD("X", framework::GradVarName("X"));
5260
}
61+
62+
framework::OpKernelType GetExpectedKernelType(
63+
const framework::ExecutionContext& ctx) const override {
64+
auto input_data_type =
65+
framework::ToDataType(ctx.Input<Tensor>("X")->type());
66+
67+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
68+
}
5369
};
5470

5571
class MeanGradMaker : public framework::SingleGradOpDescMaker {
@@ -71,7 +87,8 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
7187
} // namespace paddle
7288

7389
namespace ops = paddle::operators;
74-
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
90+
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
91+
ops::MeanGradMaker);
7592
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
7693
REGISTER_OP_CPU_KERNEL(
7794
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/mul_op.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ or not. But the output only shares the LoD information with input $X$.
126126
}
127127
};
128128

129+
class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
130+
protected:
131+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
132+
const override {
133+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
134+
}
135+
};
136+
129137
class MulGradOp : public framework::OperatorWithKernel {
130138
public:
131139
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -178,7 +186,8 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker {
178186
} // namespace paddle
179187

180188
namespace ops = paddle::operators;
181-
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpGradMaker);
189+
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
190+
ops::MulOpGradMaker);
182191
REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
183192
REGISTER_OP_CPU_KERNEL(
184193
mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/pool_op.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
4040
return output_size;
4141
}
4242

43-
void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
43+
void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
4444
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
4545
PADDLE_ENFORCE(ctx->HasOutput("Out"),
4646
"Out(Output) of Pooling should not be null.");
@@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
8181
}
8282

8383
framework::OpKernelType PoolOp::GetExpectedKernelType(
84-
const framework::ExecutionContext &ctx) const {
84+
const framework::ExecutionContext& ctx) const {
8585
framework::LibraryType library_{framework::LibraryType::kPlain};
8686
std::string data_format = ctx.Attr<std::string>("data_format");
8787
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
@@ -104,15 +104,15 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
104104
layout_, library_);
105105
}
106106

107-
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
107+
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
108108
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
109109
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
110110
"Input(X@GRAD) should not be null.");
111111
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
112112
}
113113

114114
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
115-
const framework::ExecutionContext &ctx) const {
115+
const framework::ExecutionContext& ctx) const {
116116
framework::LibraryType library_{framework::LibraryType::kPlain};
117117
std::string data_format = ctx.Attr<std::string>("data_format");
118118
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
@@ -262,6 +262,14 @@ The input(X) size and output(Out) size may be different.
262262
)DOC");
263263
}
264264

265+
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
266+
protected:
267+
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
268+
const override {
269+
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
270+
}
271+
};
272+
265273
void Pool3dOpMaker::Make() {
266274
AddInput("X",
267275
"(Tensor) The input tensor of pooling operator. "
@@ -372,6 +380,7 @@ width, respectively. The input(X) size and output(Out) size may be different.
372380
namespace ops = paddle::operators;
373381

374382
REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker,
383+
ops::PoolOpInferVarType,
375384
paddle::framework::DefaultGradOpDescMaker<true>);
376385
REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad);
377386

@@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL(
383392
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>);
384393

385394
REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker,
395+
ops::PoolOpInferVarType,
386396
paddle::framework::DefaultGradOpDescMaker<true>);
387397
REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad);
388398

0 commit comments

Comments
 (0)