Skip to content

Commit 6d2deed

Browse files
authored
Merge pull request #10814 from guoshengCS/fix-ElementwiseOpInferVarType
Fix ElementwiseOpInferVarType in elementwise_op
2 parents 5b2de50 + 01fdf17 commit 6d2deed

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddle/fluid/operators/elementwise_op.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
4646
public:
4747
void operator()(const framework::OpDesc& op_desc,
4848
framework::BlockDesc* block) const override {
49-
auto x_var = op_desc.Input("X")[0];
50-
auto out_var = op_desc.Output("Out")[0];
51-
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
49+
auto x_name = op_desc.Input("X")[0];
50+
auto out_name = op_desc.Output("Out")[0];
51+
auto& x = block->FindRecursiveOrCreateVar(x_name);
52+
auto& out = block->FindRecursiveOrCreateVar(out_name);
53+
out.SetType(x.GetType());
5254
}
5355
};
5456

0 commit comments

Comments
 (0)