Skip to content

Commit 3b00ed8

Browse files
committed
Make Kernel registed as a function
1 parent 81f22bb commit 3b00ed8

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

paddle/fluid/framework/op_registry.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
9191
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
9292
StringToDataLayout(data_layout),
9393
StringToLibraryType(library_type));
94-
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
94+
OperatorWithKernel::AllOpKernels()[op_type][key] =
95+
[](const framework::ExecutionContext& ctx) {
96+
KERNEL_TYPE().Compute(ctx);
97+
};
9598

9699
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
97100
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
651651
dev_ctx = pool.Get(expected_kernel_key.place_);
652652
}
653653

654-
kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
654+
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
655655

656656
if (!transfered_inplace_vars.empty()) {
657657
// there is inplace variable has been transfered.

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
347347

348348
class OperatorWithKernel : public OperatorBase {
349349
public:
350+
using OpKernelFunc = std::function<void(const ExecutionContext&)>;
350351
using OpKernelMap =
351-
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
352-
OpKernelType::Hash>;
352+
std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;
353353

354354
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
355355
const VariableNameMap& outputs, const AttributeMap& attrs)

0 commit comments

Comments
 (0)