Skip to content

Commit a76d0dd

Browse files
committed
MKL-DNN activations improvements
1 parent 1c81301 commit a76d0dd

File tree

2 files changed

+34
-71
lines changed

2 files changed

+34
-71
lines changed

paddle/fluid/operators/activation_op.cc

Lines changed: 34 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ namespace operators {
4141
\
4242
protected: \
4343
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
44-
auto *op = new ::paddle::framework::OpDesc(); \
44+
auto* op = new ::paddle::framework::OpDesc(); \
4545
op->SetType(#KERNEL_TYPE "_grad"); \
4646
op->SetInput("Out", Output("Out")); \
4747
op->SetInput(::paddle::framework::GradVarName("Out"), \
@@ -54,23 +54,50 @@ namespace operators {
5454
} \
5555
}
5656

57+
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
58+
const framework::OperatorWithKernel& oper,
59+
const std::string& name) {
60+
framework::LibraryType library{framework::LibraryType::kPlain};
61+
#ifdef PADDLE_WITH_MKLDNN
62+
auto it = oper.Attrs().find("use_mkldnn");
63+
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
64+
platform::CanMKLDNNBeUsed(ctx)) {
65+
library = framework::LibraryType::kMKLDNN;
66+
}
67+
#endif
68+
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
69+
return framework::OpKernelType(
70+
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
71+
ctx.GetPlace(), layout, library);
72+
}
73+
5774
class ActivationOp : public framework::OperatorWithKernel {
5875
public:
5976
using framework::OperatorWithKernel::OperatorWithKernel;
6077

61-
void InferShape(framework::InferShapeContext *ctx) const override {
78+
void InferShape(framework::InferShapeContext* ctx) const override {
6279
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
6380
ctx->ShareLoD("X", /*->*/ "Out");
6481
}
82+
83+
framework::OpKernelType GetExpectedKernelType(
84+
const framework::ExecutionContext& ctx) const override {
85+
return GetKernelType(ctx, *this, "X");
86+
}
6587
};
6688

6789
class ActivationOpGrad : public framework::OperatorWithKernel {
6890
public:
6991
using framework::OperatorWithKernel::OperatorWithKernel;
7092

71-
void InferShape(framework::InferShapeContext *ctx) const override {
93+
void InferShape(framework::InferShapeContext* ctx) const override {
7294
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
7395
}
96+
97+
framework::OpKernelType GetExpectedKernelType(
98+
const framework::ExecutionContext& ctx) const override {
99+
return GetKernelType(ctx, *this, "Out");
100+
}
74101
};
75102

76103
__attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC(
@@ -458,22 +485,21 @@ namespace ops = paddle::operators;
458485

459486
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
460487
__macro(Sigmoid, sigmoid); \
488+
__macro(Relu, relu); \
461489
__macro(Exp, exp); \
490+
__macro(Tanh, tanh); \
462491
__macro(Ceil, ceil); \
463492
__macro(Floor, floor); \
493+
__macro(Sqrt, sqrt); \
464494
__macro(SoftRelu, soft_relu); \
465495
__macro(Relu6, relu6); \
466496
__macro(Reciprocal, reciprocal); \
467497
__macro(HardSigmoid, hard_sigmoid);
468498

469-
#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
470-
__macro(Relu, relu); \
471-
__macro(Tanh, tanh); \
472-
__macro(Sqrt, sqrt);
473-
474499
#define FOR_EACH_OP_FUNCTOR(__macro) \
475500
__macro(LogSigmoid, logsigmoid); \
476501
__macro(SoftShrink, softshrink); \
502+
__macro(Abs, abs); \
477503
__macro(Cos, cos); \
478504
__macro(Sin, sin); \
479505
__macro(Round, round); \
@@ -491,32 +517,18 @@ namespace ops = paddle::operators;
491517
__macro(Swish, swish); \
492518
__macro(ThresholdedRelu, thresholded_relu);
493519

494-
#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
495-
496520
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
497521
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
498522
::paddle::operators::OP_NAME##OpMaker, \
499523
::paddle::operators::OP_NAME##GradMaker); \
500524
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
501525

502-
#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
503-
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
504-
::paddle::operators::OP_NAME##OpMaker, \
505-
::paddle::operators::OP_NAME##GradMaker); \
506-
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
507-
508526
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
509527
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
510528
::paddle::operators::OP_NAME##OpMaker, \
511529
::paddle::framework::DefaultGradOpDescMaker<true>); \
512530
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
513531

514-
#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
515-
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
516-
::paddle::operators::OP_NAME##OpMaker, \
517-
::paddle::framework::DefaultGradOpDescMaker<true>); \
518-
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
519-
520532
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
521533
REGISTER_OP_CPU_KERNEL( \
522534
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
@@ -531,7 +543,5 @@ namespace ops = paddle::operators;
531543
ops::grad_functor<double>>);
532544

533545
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
534-
FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
535546
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
536-
FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
537547
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);

paddle/fluid/operators/mkldnn_activation_op.h

Lines changed: 0 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -62,52 +62,5 @@ class MKLDNNActivationGradKernel
6262
}
6363
};
6464

65-
namespace { // NOLINT
66-
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
67-
const framework::OperatorWithKernel& oper,
68-
const std::string& name) {
69-
framework::LibraryType library{framework::LibraryType::kPlain};
70-
#ifdef PADDLE_WITH_MKLDNN
71-
if (library == framework::LibraryType::kPlain &&
72-
platform::CanMKLDNNBeUsed(ctx)) {
73-
library = framework::LibraryType::kMKLDNN;
74-
}
75-
#endif
76-
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
77-
return framework::OpKernelType(
78-
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
79-
ctx.GetPlace(), layout, library);
80-
}
81-
} // anonymous namespace
82-
83-
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
84-
public:
85-
using framework::OperatorWithKernel::OperatorWithKernel;
86-
87-
void InferShape(framework::InferShapeContext* ctx) const override {
88-
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
89-
ctx->ShareLoD("X", /*->*/ "Out");
90-
}
91-
92-
framework::OpKernelType GetExpectedKernelType(
93-
const framework::ExecutionContext& ctx) const override {
94-
return GetKernelType(ctx, *this, "X");
95-
}
96-
};
97-
98-
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
99-
public:
100-
using framework::OperatorWithKernel::OperatorWithKernel;
101-
102-
void InferShape(framework::InferShapeContext* ctx) const override {
103-
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
104-
}
105-
106-
framework::OpKernelType GetExpectedKernelType(
107-
const framework::ExecutionContext& ctx) const override {
108-
return GetKernelType(ctx, *this, "Out");
109-
}
110-
};
111-
11265
} // namespace operators
11366
} // namespace paddle

0 commit comments

Comments (0)