@@ -41,7 +41,7 @@ namespace operators {
41
41
\
42
42
protected: \
43
43
std::unique_ptr<::paddle::framework::OpDesc> Apply () const override { \
44
- auto * op = new ::paddle::framework::OpDesc (); \
44
+ auto * op = new ::paddle::framework::OpDesc (); \
45
45
op->SetType (#KERNEL_TYPE " _grad" ); \
46
46
op->SetInput (" Out" , Output (" Out" )); \
47
47
op->SetInput (::paddle::framework::GradVarName (" Out" ), \
@@ -54,23 +54,50 @@ namespace operators {
54
54
} \
55
55
}
56
56
57
+ framework::OpKernelType GetKernelType (const framework::ExecutionContext& ctx,
58
+ const framework::OperatorWithKernel& oper,
59
+ const std::string& name) {
60
+ framework::LibraryType library{framework::LibraryType::kPlain };
61
+ #ifdef PADDLE_WITH_MKLDNN
62
+ auto it = oper.Attrs ().find (" use_mkldnn" );
63
+ if (library == framework::LibraryType::kPlain && it != oper.Attrs ().end () &&
64
+ platform::CanMKLDNNBeUsed (ctx)) {
65
+ library = framework::LibraryType::kMKLDNN ;
66
+ }
67
+ #endif
68
+ framework::DataLayout layout = framework::DataLayout::kAnyLayout ;
69
+ return framework::OpKernelType (
70
+ framework::ToDataType (ctx.Input <framework::Tensor>(name)->type ()),
71
+ ctx.GetPlace (), layout, library);
72
+ }
73
+
57
74
class ActivationOp : public framework ::OperatorWithKernel {
58
75
public:
59
76
using framework::OperatorWithKernel::OperatorWithKernel;
60
77
61
- void InferShape (framework::InferShapeContext * ctx) const override {
78
+ void InferShape (framework::InferShapeContext* ctx) const override {
62
79
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
63
80
ctx->ShareLoD (" X" , /* ->*/ " Out" );
64
81
}
82
+
83
+ framework::OpKernelType GetExpectedKernelType (
84
+ const framework::ExecutionContext& ctx) const override {
85
+ return GetKernelType (ctx, *this , " X" );
86
+ }
65
87
};
66
88
67
89
class ActivationOpGrad : public framework ::OperatorWithKernel {
68
90
public:
69
91
using framework::OperatorWithKernel::OperatorWithKernel;
70
92
71
- void InferShape (framework::InferShapeContext * ctx) const override {
93
+ void InferShape (framework::InferShapeContext* ctx) const override {
72
94
ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" Out" ));
73
95
}
96
+
97
+ framework::OpKernelType GetExpectedKernelType (
98
+ const framework::ExecutionContext& ctx) const override {
99
+ return GetKernelType (ctx, *this , " Out" );
100
+ }
74
101
};
75
102
76
103
__attribute__ ((unused)) constexpr char SigmoidDoc[] = R"DOC(
@@ -458,22 +485,21 @@ namespace ops = paddle::operators;
458
485
459
486
#define FOR_EACH_INPLACE_OP_FUNCTOR (__macro ) \
460
487
__macro (Sigmoid, sigmoid); \
488
+ __macro (Relu, relu); \
461
489
__macro (Exp, exp); \
490
+ __macro (Tanh, tanh); \
462
491
__macro (Ceil, ceil); \
463
492
__macro (Floor, floor); \
493
+ __macro (Sqrt, sqrt); \
464
494
__macro (SoftRelu, soft_relu); \
465
495
__macro (Relu6, relu6); \
466
496
__macro (Reciprocal, reciprocal); \
467
497
__macro (HardSigmoid, hard_sigmoid);
468
498
469
- #define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR (__macro ) \
470
- __macro (Relu, relu); \
471
- __macro (Tanh, tanh); \
472
- __macro (Sqrt, sqrt);
473
-
474
499
#define FOR_EACH_OP_FUNCTOR (__macro ) \
475
500
__macro (LogSigmoid, logsigmoid); \
476
501
__macro (SoftShrink, softshrink); \
502
+ __macro (Abs, abs); \
477
503
__macro (Cos, cos); \
478
504
__macro (Sin, sin); \
479
505
__macro (Round, round); \
@@ -491,32 +517,18 @@ namespace ops = paddle::operators;
491
517
__macro (Swish, swish); \
492
518
__macro (ThresholdedRelu, thresholded_relu);
493
519
494
- #define FOR_EACH_MKLDNN_OP_FUNCTOR (__macro ) __macro(Abs, abs);
495
-
496
520
#define REGISTER_INPLACE_ACTIVATION_OP (OP_NAME, KERNEL_TYPE ) \
497
521
REGISTER_OPERATOR (KERNEL_TYPE, ::paddle::operators::ActivationOp, \
498
522
::paddle::operators::OP_NAME##OpMaker, \
499
523
::paddle::operators::OP_NAME##GradMaker); \
500
524
REGISTER_OPERATOR (KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
501
525
502
- #define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP (OP_NAME, KERNEL_TYPE ) \
503
- REGISTER_OPERATOR (KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
504
- ::paddle::operators::OP_NAME##OpMaker, \
505
- ::paddle::operators::OP_NAME##GradMaker); \
506
- REGISTER_OPERATOR (KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
507
-
508
526
#define REGISTER_ACTIVATION_OP (OP_NAME, KERNEL_TYPE ) \
509
527
REGISTER_OPERATOR (KERNEL_TYPE, ::paddle::operators::ActivationOp, \
510
528
::paddle::operators::OP_NAME##OpMaker, \
511
529
::paddle::framework::DefaultGradOpDescMaker<true >); \
512
530
REGISTER_OPERATOR (KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
513
531
514
- #define REGISTER_ACTIVATION_MKLDNN_OP (OP_NAME, KERNEL_TYPE ) \
515
- REGISTER_OPERATOR (KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
516
- ::paddle::operators::OP_NAME##OpMaker, \
517
- ::paddle::framework::DefaultGradOpDescMaker<true >); \
518
- REGISTER_OPERATOR (KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
519
-
520
532
#define REGISTER_ACTIVATION_CPU_KERNEL (act_type, functor, grad_functor ) \
521
533
REGISTER_OP_CPU_KERNEL ( \
522
534
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
@@ -531,7 +543,5 @@ namespace ops = paddle::operators;
531
543
ops::grad_functor<double >>);
532
544
533
545
FOR_EACH_OP_FUNCTOR (REGISTER_ACTIVATION_OP);
534
- FOR_EACH_MKLDNN_OP_FUNCTOR (REGISTER_ACTIVATION_MKLDNN_OP);
535
546
FOR_EACH_INPLACE_OP_FUNCTOR (REGISTER_INPLACE_ACTIVATION_OP);
536
- FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR (REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
537
547
FOR_EACH_KERNEL_FUNCTOR (REGISTER_ACTIVATION_CPU_KERNEL);
0 commit comments