@@ -76,23 +76,31 @@ class OpRegistry {
76
76
// Primary template of the recursive kernel registrar; declaration only.
// `at_end` selects which partial specialization is instantiated and `I` is
// the index of the kernel type currently being handled.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
78
78
79
+ template <typename PlaceType, typename T, typename Func>
80
+ inline void RegisterKernelClass (const char * op_type, const char * library_type,
81
+ Func func) {
82
+ std::string library (library_type);
83
+ std::string data_layout = " ANYLAYOUT" ;
84
+ if (library == " MKLDNN" ) {
85
+ data_layout = " MKLDNNLAYOUT" ;
86
+ }
87
+ OpKernelType key (ToDataType (std::type_index (typeid (T))), PlaceType (),
88
+ StringToDataLayout (data_layout),
89
+ StringToLibraryType (library_type));
90
+ OperatorWithKernel::AllOpKernels ()[op_type][key] = func;
91
+ }
92
+
79
93
template <typename PlaceType, size_t I, typename ... KernelTypes>
80
94
struct OpKernelRegistrarFunctor <PlaceType, false , I, KernelTypes...> {
81
95
using KERNEL_TYPE =
82
96
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
83
97
84
98
void operator ()(const char * op_type, const char * library_type) const {
85
99
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
86
- std::string library (library_type);
87
- std::string data_layout = " ANYLAYOUT" ;
88
- if (library == " MKLDNN" ) {
89
- data_layout = " MKLDNNLAYOUT" ;
90
- }
91
- OpKernelType key (ToDataType (std::type_index (typeid (T))), PlaceType (),
92
- StringToDataLayout (data_layout),
93
- StringToLibraryType (library_type));
94
- OperatorWithKernel::AllOpKernels ()[op_type][key].reset (new KERNEL_TYPE);
95
-
100
+ RegisterKernelClass<PlaceType, T>(
101
+ op_type, library_type, [](const framework::ExecutionContext& ctx) {
102
+ KERNEL_TYPE ().Compute (ctx);
103
+ });
96
104
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
97
105
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1 , KernelTypes...>
98
106
func;
@@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar {
116
124
}
117
125
};
118
126
127
+ template <typename PlaceType, bool at_end, size_t I, typename ... KernelType>
128
+ struct OpKernelRegistrarFunctorEx ;
129
+
130
+ template <typename PlaceType, typename ... DataTypeAndKernelType>
131
+ class OpKernelRegistrarEx : public Registrar {
132
+ public:
133
+ explicit OpKernelRegistrarEx (const char * op_type, const char * library_type) {
134
+ OpKernelRegistrarFunctorEx<PlaceType, false , 0 , DataTypeAndKernelType...>
135
+ func;
136
+ func (op_type, library_type);
137
+ }
138
+ };
139
+
140
+ template <typename PlaceType, size_t I, typename ... DataTypeAndKernelType>
141
+ struct OpKernelRegistrarFunctorEx <PlaceType, true , I,
142
+ DataTypeAndKernelType...> {
143
+ void operator ()(const char * op_type, const char * library_type) const {}
144
+ };
145
+
146
+ template <typename PlaceType, size_t I, typename ... DataTypeAndKernelType>
147
+ struct OpKernelRegistrarFunctorEx <PlaceType, false , I,
148
+ DataTypeAndKernelType...> {
149
+ using Functor =
150
+ typename std::tuple_element<I + 1 ,
151
+ std::tuple<DataTypeAndKernelType...>>::type;
152
+ using T =
153
+ typename std::tuple_element<I,
154
+ std::tuple<DataTypeAndKernelType...>>::type;
155
+
156
+ void operator ()(const char * op_type, const char * library_type) const {
157
+ RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor ());
158
+
159
+ constexpr auto size =
160
+ std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
161
+ OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2 ,
162
+ DataTypeAndKernelType...>
163
+ func;
164
+ func (op_type, library_type);
165
+ }
166
+ };
167
+
119
168
/* *
120
169
* check if MACRO is used in GLOBAL NAMESPACE.
121
170
*/
@@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar {
174
223
#define REGISTER_OP_CPU_KERNEL (op_type, ...) \
175
224
REGISTER_OP_KERNEL (op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
176
225
226
/**
 * Registers functor-style kernels for an operator.
 *
 * Expands to:
 *   1. a STATIC_ASSERT_GLOBAL_NAMESPACE check for this op/library pair;
 *   2. a static OpKernelRegistrarEx<place_class, ...> object constructed from
 *      the stringified `op_type` and `library_type`;
 *   3. a function TouchOpKernelRegistrar_<op_type>_<library_type>() that
 *      calls Touch() on that object and returns 0.
 *
 * The variadic arguments are forwarded unchanged as the template arguments
 * that follow `place_class`.
 */
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...)       \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
      __reg_op_kernel_##op_type##_##library_type##__,                        \
      " REGISTER_OP_KERNEL_EX must be called in global namespace");          \
  static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__>  \
      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,         \
                                                           #library_type);   \
  int TouchOpKernelRegistrar_##op_type##_##library_type() {                  \
    __op_kernel_registrar_##op_type##_##library_type##__.Touch();            \
    return 0;                                                                \
  }
237
+
238
// Shorthand for REGISTER_OP_KERNEL_EX with library type CUDA and place class
// ::paddle::platform::CUDAPlace; the remaining arguments are forwarded as is.
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)                 \
  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
                        __VA_ARGS__)
241
+
242
// Shorthand for REGISTER_OP_KERNEL_EX with library type CPU and place class
// ::paddle::platform::CPUPlace; the remaining arguments are forwarded as is.
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...)                \
  REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, \
                        __VA_ARGS__)
244
+
177
245
/* *
178
246
* Macro to mark what Operator and Kernel
179
247
* we will use and tell the compiler to
0 commit comments