File tree Expand file tree Collapse file tree 1 file changed +13
-0
lines changed Expand file tree Collapse file tree 1 file changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -42,6 +42,18 @@ class ElementwiseOp : public framework::OperatorWithKernel {
42
42
}
43
43
};
44
44
45
+ class ElementwiseOpInferVarType : public framework ::VarTypeInference {
46
+ public:
47
+ void operator ()(const framework::OpDesc& op_desc,
48
+ framework::BlockDesc* block) const override {
49
+ auto x_name = op_desc.Input (" X" )[0 ];
50
+ auto out_name = op_desc.Output (" Out" )[0 ];
51
+ auto & x = block->FindRecursiveOrCreateVar (x_name);
52
+ auto & out = block->FindRecursiveOrCreateVar (out_name);
53
+ out.SetType (x.GetType ());
54
+ }
55
+ };
56
+
45
57
class ElementwiseOpMaker : public framework ::OpProtoAndCheckerMaker {
46
58
public:
47
59
void Make () final {
@@ -138,5 +150,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
138
150
}; \
139
151
REGISTER_OPERATOR (op_type, ::paddle::operators::ElementwiseOp, \
140
152
__ElemwiseOp##op_type##Maker__, \
153
+ ::paddle::operators::ElementwiseOpInferVarType, \
141
154
::paddle::framework::DefaultGradOpDescMaker<true >); \
142
155
REGISTER_OPERATOR (op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
You can’t perform that action at this time.
0 commit comments