Skip to content

Commit 01fdf17

Browse files
committed
Fix ElementwiseOpInferVarType in elementwise_op to use the default InferVarType to find var recursively
1 parent f176a9c commit 01fdf17

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

paddle/fluid/operators/elementwise_op.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel {
4242
}
4343
};
4444

45+
class ElementwiseOpInferVarType : public framework::VarTypeInference {
46+
public:
47+
void operator()(const framework::OpDesc& op_desc,
48+
framework::BlockDesc* block) const override {
49+
auto x_name = op_desc.Input("X")[0];
50+
auto out_name = op_desc.Output("Out")[0];
51+
auto& x = block->FindRecursiveOrCreateVar(x_name);
52+
auto& out = block->FindRecursiveOrCreateVar(out_name);
53+
out.SetType(x.GetType());
54+
}
55+
};
56+
4557
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
4658
public:
4759
void Make() final {
@@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
138150
}; \
139151
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
140152
__ElemwiseOp##op_type##Maker__, \
153+
::paddle::operators::ElementwiseOpInferVarType, \
141154
::paddle::framework::DefaultGradOpDescMaker<true>); \
142155
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)

0 commit comments

Comments
 (0)