@@ -76,8 +76,9 @@ class OpRegistry {
76
76
template <typename PlaceType, bool at_end, size_t I, typename ... KernelType>
77
77
struct OpKernelRegistrarFunctor ;
78
78
79
- template <typename PlaceType, typename T, typename KernelType>
80
- inline void RegisterKernelClass (const char * op_type, const char * library_type) {
79
+ template <typename PlaceType, typename T, typename Func>
80
+ inline void RegisterKernelClass (const char * op_type, const char * library_type,
81
+ Func func) {
81
82
std::string library (library_type);
82
83
std::string data_layout = " ANYLAYOUT" ;
83
84
if (library == " MKLDNN" ) {
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
86
87
OpKernelType key (ToDataType (std::type_index (typeid (T))), PlaceType (),
87
88
StringToDataLayout (data_layout),
88
89
StringToLibraryType (library_type));
89
- OperatorWithKernel::AllOpKernels ()[op_type][key]. reset ( new KernelType ()) ;
90
+ OperatorWithKernel::AllOpKernels ()[op_type][key] = func ;
90
91
}
91
92
92
93
template <typename PlaceType, size_t I, typename ... KernelTypes>
@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
96
97
97
98
void operator ()(const char * op_type, const char * library_type) const {
98
99
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
99
- RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
100
+ RegisterKernelClass<PlaceType, T>(
101
+ op_type, library_type, [](const framework::ExecutionContext& ctx) {
102
+ KERNEL_TYPE ().Compute (ctx);
103
+ });
100
104
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
101
105
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1 , KernelTypes...>
102
106
func;
@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
150
154
std::tuple<DataTypeAndKernelType...>>::type;
151
155
152
156
void operator ()(const char * op_type, const char * library_type) const {
153
- RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
157
+ RegisterKernelClass<PlaceType, T>(
158
+ op_type, library_type, [](const framework::ExecutionContext& ctx) {
159
+ KERNEL_TYPE ().Compute (ctx);
160
+ });
154
161
155
162
constexpr auto size =
156
163
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
0 commit comments