Skip to content

Commit 75ae426

Browse files
committed
Merge branch 'feature/change_op_kernel_to_func' into feature/fix_reshape_op_size
2 parents 1ce478f + 3b00ed8 commit 75ae426

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

paddle/fluid/framework/op_registry.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ class OpRegistry {
7676
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
7777
struct OpKernelRegistrarFunctor;
7878

79-
template <typename PlaceType, typename T, typename KernelType>
80-
inline void RegisterKernelClass(const char* op_type, const char* library_type) {
79+
template <typename PlaceType, typename T, typename Func>
80+
inline void RegisterKernelClass(const char* op_type, const char* library_type,
81+
Func func) {
8182
std::string library(library_type);
8283
std::string data_layout = "ANYLAYOUT";
8384
if (library == "MKLDNN") {
@@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
8687
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
8788
StringToDataLayout(data_layout),
8889
StringToLibraryType(library_type));
89-
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
90+
OperatorWithKernel::AllOpKernels()[op_type][key] = func;
9091
}
9192

9293
template <typename PlaceType, size_t I, typename... KernelTypes>
@@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
9697

9798
void operator()(const char* op_type, const char* library_type) const {
9899
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
99-
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
100+
RegisterKernelClass<PlaceType, T>(
101+
op_type, library_type, [](const framework::ExecutionContext& ctx) {
102+
KERNEL_TYPE().Compute(ctx);
103+
});
100104
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
101105
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
102106
func;
@@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
150154
std::tuple<DataTypeAndKernelType...>>::type;
151155

152156
void operator()(const char* op_type, const char* library_type) const {
153-
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
157+
RegisterKernelClass<PlaceType, T>(
158+
op_type, library_type, [](const framework::ExecutionContext& ctx) {
159+
KERNEL_TYPE().Compute(ctx);
160+
});
154161

155162
constexpr auto size =
156163
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
651651
dev_ctx = pool.Get(expected_kernel_key.place_);
652652
}
653653

654-
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
654+
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
655655

656656
if (!transfered_inplace_vars.empty()) {
657657
// there is inplace variable has been transfered.

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
347347

348348
class OperatorWithKernel : public OperatorBase {
349349
public:
350+
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
350351
using OpKernelMap =
351-
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
352-
OpKernelType::Hash>;
352+
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
353353

354354
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
355355
const VariableNameMap& outputs, const AttributeMap& attrs)

0 commit comments

Comments
 (0)