Skip to content

Commit 1c81301

Browse files
committed
Update activations for MKL-DNN
1 parent 35e5563 commit 1c81301

File tree

3 files changed

+41
-19
lines changed

3 files changed

+41
-19
lines changed

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
5252
mkldnn::memory::format::nchw);
5353

5454
// create memory primitives
55-
auto src_memory =
55+
auto src_memory = std::make_shared<mkldnn::memory>(
5656
mkldnn::memory({data_md, mkldnn_engine},
57-
static_cast<void *>(const_cast<float *>(src_data)));
57+
static_cast<void *>(const_cast<float *>(src_data))));
58+
// save source memory to device context to be referred in backward path
59+
dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);
5860
auto dst_memory =
5961
mkldnn::memory({data_md, mkldnn_engine},
6062
static_cast<void *>(const_cast<float *>(dst_data)));
@@ -69,7 +71,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
6971
forward_desc, mkldnn_engine);
7072
dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
7173

72-
auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
74+
auto eltwise = mkldnn::eltwise_forward(*forward_pd, *src_memory, dst_memory);
7375

7476
// push primitive to stream and wait until it's executed
7577
std::vector<mkldnn::primitive> pipeline = {eltwise};
@@ -83,8 +85,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
8385
const auto &mkldnn_engine = dev_ctx.GetEngine();
8486

8587
// get buffers
86-
const auto *x = ctx.template Input<Tensor>("X");
87-
const auto *src = x->template data<T>();
88+
const auto *x = ctx.template Input<Tensor>("Out");
8889

8990
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
9091
const auto *diff_dst = dout->template data<T>();
@@ -103,9 +104,11 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
103104
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
104105
mkldnn::memory::format::nchw);
105106

107+
// retrieve source memory from device context
108+
const std::shared_ptr<void> src_memory = dev_ctx.GetBlob("InputX@eltwise_pd");
109+
auto *p_src_memory = static_cast<mkldnn::memory *>(src_memory.get());
110+
106111
// create memory primitives
107-
auto src_memory = mkldnn::memory(
108-
{data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
109112
auto diff_src_memory =
110113
mkldnn::memory({data_md, mkldnn_engine},
111114
static_cast<void *>(const_cast<float *>(diff_src)));
@@ -128,8 +131,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
128131
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
129132
backward_desc, mkldnn_engine, *p_forward_pd);
130133

131-
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
132-
diff_dst_memory, diff_src_memory);
134+
auto eltwise_bwd = mkldnn::eltwise_backward(
135+
eltwise_bwd_prim_desc, *p_src_memory, diff_dst_memory, diff_src_memory);
133136

134137
// push primitive to stream and wait until it's executed
135138
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};

paddle/fluid/operators/activation_op.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,21 +458,22 @@ namespace ops = paddle::operators;
458458

459459
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
460460
__macro(Sigmoid, sigmoid); \
461-
__macro(Relu, relu); \
462461
__macro(Exp, exp); \
463-
__macro(Tanh, tanh); \
464462
__macro(Ceil, ceil); \
465463
__macro(Floor, floor); \
466-
__macro(Sqrt, sqrt); \
467464
__macro(SoftRelu, soft_relu); \
468465
__macro(Relu6, relu6); \
469466
__macro(Reciprocal, reciprocal); \
470467
__macro(HardSigmoid, hard_sigmoid);
471468

469+
#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
470+
__macro(Relu, relu); \
471+
__macro(Tanh, tanh); \
472+
__macro(Sqrt, sqrt);
473+
472474
#define FOR_EACH_OP_FUNCTOR(__macro) \
473475
__macro(LogSigmoid, logsigmoid); \
474476
__macro(SoftShrink, softshrink); \
475-
__macro(Abs, abs); \
476477
__macro(Cos, cos); \
477478
__macro(Sin, sin); \
478479
__macro(Round, round); \
@@ -490,18 +491,32 @@ namespace ops = paddle::operators;
490491
__macro(Swish, swish); \
491492
__macro(ThresholdedRelu, thresholded_relu);
492493

494+
#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
495+
493496
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
494497
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
495498
::paddle::operators::OP_NAME##OpMaker, \
496499
::paddle::operators::OP_NAME##GradMaker); \
497500
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
498501

502+
#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
503+
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
504+
::paddle::operators::OP_NAME##OpMaker, \
505+
::paddle::operators::OP_NAME##GradMaker); \
506+
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
507+
499508
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
500509
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
501510
::paddle::operators::OP_NAME##OpMaker, \
502511
::paddle::framework::DefaultGradOpDescMaker<true>); \
503512
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
504513

514+
#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
515+
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
516+
::paddle::operators::OP_NAME##OpMaker, \
517+
::paddle::framework::DefaultGradOpDescMaker<true>); \
518+
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
519+
505520
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
506521
REGISTER_OP_CPU_KERNEL( \
507522
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
@@ -516,5 +531,7 @@ namespace ops = paddle::operators;
516531
ops::grad_functor<double>>);
517532

518533
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
534+
FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
519535
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
536+
FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
520537
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);

paddle/fluid/operators/mkldnn_activation_op.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include <string>
17+
1618
#include "paddle/fluid/framework/eigen.h"
1719
#include "paddle/fluid/framework/op_registry.h"
1820
#include "paddle/fluid/operators/detail/safe_ref.h"
@@ -61,9 +63,9 @@ class MKLDNNActivationGradKernel
6163
};
6264

6365
namespace { // NOLINT
64-
framework::OpKernelType GetKernelType(
65-
const framework::ExecutionContext& ctx,
66-
const framework::OperatorWithKernel& oper) {
66+
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
67+
const framework::OperatorWithKernel& oper,
68+
const std::string& name) {
6769
framework::LibraryType library{framework::LibraryType::kPlain};
6870
#ifdef PADDLE_WITH_MKLDNN
6971
if (library == framework::LibraryType::kPlain &&
@@ -73,7 +75,7 @@ framework::OpKernelType GetKernelType(
7375
#endif
7476
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
7577
return framework::OpKernelType(
76-
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
78+
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
7779
ctx.GetPlace(), layout, library);
7880
}
7981
} // anonymous namespace
@@ -89,7 +91,7 @@ class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
8991

9092
framework::OpKernelType GetExpectedKernelType(
9193
const framework::ExecutionContext& ctx) const override {
92-
return GetKernelType(ctx, *this);
94+
return GetKernelType(ctx, *this, "X");
9395
}
9496
};
9597

@@ -103,7 +105,7 @@ class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
103105

104106
framework::OpKernelType GetExpectedKernelType(
105107
const framework::ExecutionContext& ctx) const override {
106-
return GetKernelType(ctx, *this);
108+
return GetKernelType(ctx, *this, "Out");
107109
}
108110
};
109111

0 commit comments

Comments (0)