
Commit 7205d33

Merge pull request #10597 from kbinias/mkldnn-activations-improvments
Update activations for MKL-DNN
2 parents 2a77fc5 + 24904b9 commit 7205d33

File tree

4 files changed: +172 -117 lines changed

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 135 additions & 67 deletions
@@ -15,6 +15,7 @@
 #include "mkldnn.hpp"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/mkldnn_activation_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"

 namespace paddle {
 namespace operators {
@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;

 namespace {
+std::string gethash(const mkldnn::memory::dims &operand_dims,
+                    const mkldnn::algorithm algorithm) {
+  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dim2str(operand_dims) + std::to_string(algorithm);
+}
+
 template <typename T, typename ExecContext>
 void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                      const T alpha = 0, const T beta = 0) {
@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto *src_data = src->template data<T>();

   auto *dst = ctx.template Output<Tensor>("Out");
-  const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
+  T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());

   // get memory dim
   PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
                  "Input dim must be with 2 or 4");
   std::vector<int> src_tz = framework::vectorize2int(src->dims());

-  // create memory description
-  auto data_md = src_tz.size() == 2
-                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nc)
-                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-  // create memory primitives
-  auto src_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(src_data)));
-  auto dst_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(dst_data)));
-
-  auto forward_desc = mkldnn::eltwise_forward::desc(
-      mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
-
-  // save prim desc into global device context to be referred in backward path
-  const std::string key = ctx.op().Output("Out");
-  const std::string key_eltwise_pd = key + "@eltwise_pd";
-  auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
-      forward_desc, mkldnn_engine);
-  dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
-
-  auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
+  const std::string key = gethash(src_tz, algorithm);
+  const std::string key_src_data =
+      key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+  const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
+  const std::string key_fwd = key + "@eltwise_fwd";
+
+  auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
+      dev_ctx.GetBlob(key_fwd));
+
+  // save input data to be referred in backward path
+  auto p_src_data = std::make_shared<const T *>(src_data);
+  dev_ctx.SetBlob(key_src_data, p_src_data);
+
+  if (p_fwd == nullptr) {
+    // create memory description
+    auto data_md = src_tz.size() == 2
+                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nc)
+                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nchw);
+
+    // create memory primitives
+    auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
+    dev_ctx.SetBlob(key_src_mem, p_src_mem);
+
+    auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
+    dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
+
+    auto fwd_desc = mkldnn::eltwise_forward::desc(
+        mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
+    auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
+        fwd_desc, mkldnn_engine);
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
+    p_fwd = std::make_shared<mkldnn::eltwise_forward>(
+        *p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
+    dev_ctx.SetBlob(key_fwd, p_fwd);
+  } else {
+    // primitives already exist
+    auto p_src_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
+    PADDLE_ENFORCE(p_src_mem != nullptr,
+                   "Fail to find eltwise p_src_mem in device context.");
+    auto p_dst_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
+    PADDLE_ENFORCE(p_dst_mem != nullptr,
+                   "Fail to find eltwise p_src_mem in device context.");
+
+    p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
+    p_dst_mem->set_data_handle(dst_data);
+  }

   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {eltwise};
+  std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }

@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto &mkldnn_engine = dev_ctx.GetEngine();

   // get buffers
-  const auto *x = ctx.template Input<Tensor>("X");
-  const auto *src = x->template data<T>();
+  const auto *out = ctx.template Input<Tensor>("Out");

   auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
   const auto *diff_dst = dout->template data<T>();
@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());

   // get memory dim
-  std::vector<int> src_tz = framework::vectorize2int(x->dims());
-
-  // create memory description
-  auto data_md = src_tz.size() == 2
-                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nc)
-                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-  // create memory primitives
-  auto src_memory = mkldnn::memory(
-      {data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
-  auto diff_src_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_src)));
-  auto diff_dst_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_dst)));
-
-  auto backward_desc =
-      mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
-
-  // retrieve eltwise primitive desc from device context
-  const std::string key = ctx.op().Input("Out");
-  const std::string key_eltwise_pd = key + "@eltwise_pd";
-  const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
-  PADDLE_ENFORCE(forward_pd != nullptr,
-                 "Fail to find eltwise_pd in device context");
-  auto *p_forward_pd =
-      static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
-
-  auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
-      backward_desc, mkldnn_engine, *p_forward_pd);
-
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
-                                              diff_dst_memory, diff_src_memory);
+  std::vector<int> src_tz = framework::vectorize2int(out->dims());
+
+  const std::string key = gethash(src_tz, algorithm);
+  const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
+  const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
+  const std::string key_grad = key + "@eltwise_grad";
+
+  const std::string key_src_data =
+      key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
+  const auto p_src_data =
+      std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
+
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+  auto p_src_mem =
+      std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
+  p_src_mem->set_data_handle(*p_src_data.get());
+
+  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
+      dev_ctx.GetBlob(key_grad));
+
+  if (p_grad == nullptr) {
+    // create memory description
+    auto data_md = src_tz.size() == 2
+                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nc)
+                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nchw);
+
+    // create memory primitives
+    std::shared_ptr<void> p_diff_src_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
+    dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
+    std::shared_ptr<void> p_diff_dst_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
+    dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
+
+    auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
+                                                   alpha, beta);
+
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
+        dev_ctx.GetBlob(key_fwd_pd).get());
+
+    auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
+        bwd_desc, mkldnn_engine, *p_fwd_pd);
+
+    p_grad = std::make_shared<mkldnn::eltwise_backward>(
+        eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
+        *(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
+        *(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
+  } else {
+    // primitives already exist
+    auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_src_mem));
+    auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_dst_mem));
+
+    p_diff_src_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_src));
+    p_diff_dst_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_dst));
+  }

   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
+  std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }
 }  // anonymous namespace
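
The forward kernel above now caches its MKL-DNN memory objects and the eltwise_forward primitive in the device context under a string key built by gethash() from the input dimensions and the algorithm, and on later calls it only swaps the data handles. Below is a minimal standalone sketch of that lookup-or-create pattern; BlobStore, make_key, FakePrimitive and run() are hypothetical stand-ins for MKLDNNDeviceContext::SetBlob/GetBlob, gethash() and the real MKL-DNN primitives, not the Paddle or MKL-DNN API.

// Sketch only: illustrates caching objects under a string key, as the kernel
// above does, using simplified stand-in types.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct BlobStore {  // stand-in for MKLDNNDeviceContext blob storage
  std::unordered_map<std::string, std::shared_ptr<void>> blobs;
  void SetBlob(const std::string& k, std::shared_ptr<void> v) { blobs[k] = v; }
  std::shared_ptr<void> GetBlob(const std::string& k) const {
    auto it = blobs.find(k);
    return it == blobs.end() ? nullptr : it->second;
  }
};

// Same idea as gethash(): dims plus algorithm id name a cached primitive.
std::string make_key(const std::vector<int>& dims, int algorithm) {
  std::string key;
  for (int d : dims) key += std::to_string(d) + "-";
  return key + std::to_string(algorithm);
}

struct FakePrimitive {  // stand-in for mkldnn::eltwise_forward
  void* data_handle = nullptr;
};

void run(BlobStore& ctx, const std::vector<int>& dims, int algo, void* src) {
  const std::string key = make_key(dims, algo) + "@eltwise_fwd";
  auto prim = std::static_pointer_cast<FakePrimitive>(ctx.GetBlob(key));
  if (prim == nullptr) {
    prim = std::make_shared<FakePrimitive>();  // expensive construction, once
    ctx.SetBlob(key, prim);
    std::cout << "created primitive for " << key << "\n";
  }
  prim->data_handle = src;  // cheap per-call update, like set_data_handle()
}

int main() {
  BlobStore ctx;
  float buf[8] = {};
  run(ctx, {2, 4}, 1, buf);  // creates and caches
  run(ctx, {2, 4}, 1, buf);  // reuses the cached primitive
}

The point of the pattern is that descriptor and primitive construction happens once per shape/algorithm pair, while every later invocation only updates the pointers the cached primitive reads from and writes to.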

paddle/fluid/operators/activation_op.cc

Lines changed: 30 additions & 3 deletions
@@ -41,7 +41,7 @@ namespace operators {
                                                                           \
  protected:                                                               \
   std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {   \
-    auto *op = new ::paddle::framework::OpDesc();                         \
+    auto* op = new ::paddle::framework::OpDesc();                         \
     op->SetType(#KERNEL_TYPE "_grad");                                    \
     op->SetInput("Out", Output("Out"));                                   \
     op->SetInput(::paddle::framework::GradVarName("Out"),                 \
@@ -54,23 +54,50 @@ namespace operators {
   }                                                                       \
   }

+framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
+                                      const framework::OperatorWithKernel& oper,
+                                      const std::string& name) {
+  framework::LibraryType library{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+  auto it = oper.Attrs().find("use_mkldnn");
+  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library = framework::LibraryType::kMKLDNN;
+  }
+#endif
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
+      ctx.GetPlace(), layout, library);
+}
+
 class ActivationOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "X");
+  }
 };

 class ActivationOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "Out");
+  }
 };

 __attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC(
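
With this change ActivationOp and ActivationOpGrad route kernel selection through the shared GetKernelType() helper, which keeps the plain library unless the op carries a use_mkldnn attribute and platform::CanMKLDNNBeUsed(ctx) agrees. The sketch below isolates just that selection logic; AttrMap, SelectLibrary and can_mkldnn_be_used are simplified, hypothetical stand-ins for Paddle's attribute map and platform check, not the real API.

// Sketch only: mirrors the library-selection check in GetKernelType().
#include <iostream>
#include <map>
#include <string>

enum class LibraryType { kPlain, kMKLDNN };
using AttrMap = std::map<std::string, int>;  // stand-in for the op's attributes

LibraryType SelectLibrary(const AttrMap& attrs, bool can_mkldnn_be_used) {
  LibraryType library = LibraryType::kPlain;
  // As in the diff, only the presence of "use_mkldnn" is checked here.
  auto it = attrs.find("use_mkldnn");
  if (library == LibraryType::kPlain && it != attrs.end() &&
      can_mkldnn_be_used) {
    library = LibraryType::kMKLDNN;
  }
  return library;
}

int main() {
  AttrMap attrs{{"use_mkldnn", 1}};
  bool mkldnn = SelectLibrary(attrs, /*can_mkldnn_be_used=*/true) ==
                LibraryType::kMKLDNN;
  std::cout << (mkldnn ? "kMKLDNN" : "kPlain") << "\n";
}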

paddle/fluid/operators/mkldnn_activation_op.h

Lines changed: 2 additions & 47 deletions
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <string>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -60,52 +62,5 @@ class MKLDNNActivationGradKernel
   }
 };

-namespace {  // NOLINT
-framework::OpKernelType GetKernelType(
-    const framework::ExecutionContext& ctx,
-    const framework::OperatorWithKernel& oper) {
-  framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain &&
-      platform::CanMKLDNNBeUsed(ctx)) {
-    library = framework::LibraryType::kMKLDNN;
-  }
-#endif
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-      ctx.GetPlace(), layout, library);
-}
-}  // anonymous namespace
-
-class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
-  }
-};
-
-class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
-  }
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
   return static_cast<void*>(const_cast<Type*>(t));
 }

+template <typename Type>
+void* to_void_reinterpret_cast(const Type* t) {
+  return reinterpret_cast<void*>(const_cast<Type*>(t));
+}
+
 template <class Type>
 using tf_desc = typename Type::desc;
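
The new to_void_reinterpret_cast() complements the existing to_void_cast(): both drop constness and return a void*, one via static_cast and one via reinterpret_cast, so the activation kernels can hand read-only buffers to mkldnn::memory constructors and to set_data_handle(). A small self-contained usage sketch follows; the two helpers are copied from the header, while main() is illustrative only.

// Sketch only: demonstrates the two const-stripping helpers on a plain buffer.
#include <iostream>

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}

int main() {
  const float src[4] = {1.f, 2.f, 3.f, 4.f};
  void* h1 = to_void_cast(src);              // used when constructing memory
  void* h2 = to_void_reinterpret_cast(src);  // used when resetting data handles
  std::cout << (h1 == h2 ? "same address" : "different") << "\n";
}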
