@@ -21,6 +21,7 @@ limitations under the License. */
21
21
#include < vector>
22
22
#include " paddle/fluid/framework/grad_op_desc_maker.h"
23
23
#include " paddle/fluid/framework/inplace_op_inference.h"
24
+ #include " paddle/fluid/framework/no_need_buffer_vars_inference.h"
24
25
#include " paddle/fluid/framework/op_info.h"
25
26
#include " paddle/fluid/framework/op_proto_maker.h"
26
27
#include " paddle/fluid/framework/operator.h"
@@ -36,27 +37,86 @@ enum OpInfoFillType {
36
37
kGradOpDescMaker = 2 ,
37
38
kVarTypeInference = 3 ,
38
39
kShapeInference = 4 ,
39
- kInplaceOpInference = 5
40
+ kInplaceOpInference = 5 ,
41
+ kNoNeedBufferVarsInference = 6 ,
42
+ kUnknown = -1
40
43
};
41
44
45
+ namespace internal {
46
+ template <typename T, OpInfoFillType kType >
47
+ struct TypePair {
48
+ using Type = T;
49
+ static constexpr OpInfoFillType kFillType = kType ;
50
+ };
51
+
52
+ using OpRegistryClasses = std::tuple< // NOLINT
53
+ TypePair<OperatorBase, kOperator >, // NOLINT
54
+ TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker >, // NOLINT
55
+ TypePair<GradOpDescMakerBase, kGradOpDescMaker >, // NOLINT
56
+ TypePair<VarTypeInference, kVarTypeInference >, // NOLINT
57
+ TypePair<InferShapeBase, kShapeInference >, // NOLINT
58
+ TypePair<InplaceOpInference, kInplaceOpInference >, // NOLINT
59
+ TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference > // NOLINT
60
+ >;
61
+
62
+ static constexpr int kOpRegistryClassNumber =
63
+ std::tuple_size<OpRegistryClasses>::value;
64
+
65
+ template <typename T, int kPos , bool kIsBounded /* = true*/ >
66
+ struct IsMatchedBaseTypeImpl {
67
+ using PairType = typename std::tuple_element<kPos , OpRegistryClasses>::type;
68
+ static constexpr bool kValue =
69
+ std::is_base_of<typename PairType::Type, T>::value;
70
+ };
71
+
72
+ template <typename T, int kPos >
73
+ struct IsMatchedBaseTypeImpl <T, kPos , false > {
74
+ static constexpr bool kValue = false ;
75
+ };
76
+
77
+ template <typename T, int kPos >
78
+ static inline constexpr bool IsMatchedBaseType () {
79
+ return IsMatchedBaseTypeImpl<
80
+ T, kPos , (kPos >= 0 && kPos < kOpRegistryClassNumber )>::kValue ;
81
+ }
82
+
83
+ template <typename T, int kStart , int kEnd , bool kIsEnd , bool kIsMatched >
84
+ struct OpInfoFillTypeGetterImpl {};
85
+
86
+ // This case should not happen
87
+ template <typename T, int kStart , int kEnd >
88
+ struct OpInfoFillTypeGetterImpl <T, kStart , kEnd , true , true > {};
89
+
90
+ template <typename T, int kStart , int kEnd >
91
+ struct OpInfoFillTypeGetterImpl <T, kStart , kEnd , true , false > {
92
+ static constexpr OpInfoFillType kType = kUnknown ;
93
+ };
94
+
95
+ template <typename T, int kStart , int kEnd >
96
+ struct OpInfoFillTypeGetterImpl <T, kStart , kEnd , false , false > {
97
+ static constexpr OpInfoFillType kType =
98
+ OpInfoFillTypeGetterImpl<T, kStart + 1 , kEnd , kStart + 1 == kEnd ,
99
+ IsMatchedBaseType<T, kStart + 1 >()>::kType ;
100
+ };
101
+
102
+ template <typename T, int kStart , int kEnd >
103
+ struct OpInfoFillTypeGetterImpl <T, kStart , kEnd , false , true > {
104
+ using PairType = typename std::tuple_element<kStart , OpRegistryClasses>::type;
105
+ static constexpr OpInfoFillType kType = PairType::kFillType ;
106
+ };
107
+
108
+ template <typename T>
109
+ using OpInfoFillTypeGetter =
110
+ OpInfoFillTypeGetterImpl<T, 0 , kOpRegistryClassNumber ,
111
+ kOpRegistryClassNumber == 0 ,
112
+ IsMatchedBaseType<T, 0 >()>;
113
+
114
+ } // namespace internal
115
+
42
116
template <typename T>
43
117
struct OpInfoFillTypeID {
44
118
static constexpr OpInfoFillType ID () {
45
- return std::is_base_of<OperatorBase, T>::value
46
- ? kOperator
47
- : (std::is_base_of<OpProtoAndCheckerMaker, T>::value
48
- ? kOpProtoAndCheckerMaker
49
- : (std::is_base_of<GradOpDescMakerBase, T>::value
50
- ? kGradOpDescMaker
51
- : (std::is_base_of<VarTypeInference, T>::value
52
- ? kVarTypeInference
53
- : (std::is_base_of<InferShapeBase, T>::value
54
- ? kShapeInference
55
- : (std::is_base_of<
56
- InplaceOpInference, T>::value
57
- ? kInplaceOpInference
58
- : static_cast <OpInfoFillType>(
59
- -1 ))))));
119
+ return internal::OpInfoFillTypeGetter<T>::kType ;
60
120
}
61
121
};
62
122
@@ -156,6 +216,18 @@ struct OpInfoFiller<T, kInplaceOpInference> {
156
216
}
157
217
};
158
218
219
+ template <typename T>
220
+ struct OpInfoFiller <T, kNoNeedBufferVarsInference > {
221
+ void operator ()(const char * op_type, OpInfo* info) const {
222
+ info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs,
223
+ const VariableNameMap& outputs,
224
+ const AttributeMap& attrs) {
225
+ T infer (inputs, outputs, attrs);
226
+ return infer ();
227
+ };
228
+ }
229
+ };
230
+
159
231
} // namespace details
160
232
161
233
} // namespace framework
0 commit comments