Skip to content

Commit 685f037

Browse files
author
chengduo
authored
Merge pull request #8890 from chengduoZH/feature/fix_bug_of_elementwise
Add ElementwiseOpInferVarType for Elementwise_op
2 parents dd1244f + 53d19f5 commit 685f037

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

paddle/fluid/operators/elementwise_add_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
2929
} // namespace paddle
3030

3131
namespace ops = paddle::operators;
32-
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
33-
elementwise_add_grad, ops::ElementwiseOpGrad);
32+
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
33+
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
34+
paddle::framework::DefaultGradOpDescMaker<true>);
35+
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
36+
3437
REGISTER_OP_CPU_KERNEL(
3538
elementwise_add,
3639
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,

paddle/fluid/operators/elementwise_op.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel {
4141
}
4242
};
4343

44+
class ElementwiseOpInferVarType : public framework::VarTypeInference {
45+
public:
46+
void operator()(const framework::OpDesc& op_desc,
47+
framework::BlockDesc* block) const override {
48+
auto x_var = op_desc.Input("X")[0];
49+
auto out_var = op_desc.Output("Out")[0];
50+
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
51+
}
52+
};
53+
4454
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
4555
public:
4656
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)

0 commit comments

Comments
 (0)