Commit 60647c9

Merge pull request #11519 from jczaja/prv-softmax-mkldnn-grad-operator
MKLDNN: SoftmaxGrad Op
2 parents 3d1afe2 + 98f3ad3 commit 60647c9

4 files changed: +318 −55 lines

cmake/external/mkldnn.cmake

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLDNN_DEPENDS}
     GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
-    GIT_TAG "db3424ad44901513c03a1ea31ccaacdf633fbe9f"
+    GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51"
     PREFIX ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}

paddle/fluid/operators/softmax_mkldnn_op.cc

Lines changed: 167 additions & 50 deletions
@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
 using mkldnn::memory;  // Note: paddle has also "memory" namespace
 using mkldnn::primitive;
 using mkldnn::softmax_forward;
+using mkldnn::softmax_backward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
+using platform::to_void_cast;
+
+class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
+ public:
+  SoftmaxMKLDNNHandler(
+      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
+      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+      const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        softmax_pd_(softmax_pd) {}
+
+  SoftmaxMKLDNNHandler(
+      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
+      std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
+      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+      const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        softmax_pd_(softmax_pd),
+        softmax_bwd_pd_(softmax_bwd_pd) {
+    // If we are in the Grad operator, update the key with a BWD suffix to
+    // distinguish these memory primitives from the FWD ones
+    key_ += "-BWD";
+  }
+
+  std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    /* Generate key */
+    auto prim_key = key_ + "@softmax_p";
+
+    auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
+        dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find softmax primitive in device context");
+    if (softmax_p == nullptr) {
+      softmax_p = std::make_shared<mkldnn::softmax_forward>(
+          *(softmax_pd_.get()),
+          *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
+      dev_ctx_.SetBlob(prim_key, softmax_p);
+    } else {
+      is_reusing_ = true;
+    }
+
+    return softmax_p;
+  }
+
+  std::shared_ptr<mkldnn::softmax_backward> AcquireSoftmaxBackward(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
+    auto prim_key = key_ + "@softmax_bwd_p";
+    auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
+        dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find softmax backward primitive in device context");
+    if (softmax_bwd_p == nullptr) {
+      softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
+          *softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()),
+          *(diff_src_memory_p.get()));
+      dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
+    } else {
+      is_reusing_ = true;
+    }
+
+    return softmax_bwd_p;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
+  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
+};
 
 template <typename T>
 class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
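Note: the handler above relies on MKLDNNHandler's key-based blob cache in the device context: a primitive is created and stored on the first call, and later calls with the same key reuse it. A minimal, self-contained C++ sketch of that caching idea (not Paddle's actual MKLDNNDeviceContext API; the BlobCache type and the key string here are illustrative only):

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for the device-context blob store used by the handler.
struct BlobCache {
  std::map<std::string, std::shared_ptr<void>> blobs;
  void SetBlob(const std::string& key, std::shared_ptr<void> v) { blobs[key] = v; }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs.find(key);
    return it == blobs.end() ? nullptr : it->second;
  }
};

int main() {
  BlobCache cache;
  // Key built from tensor dims plus the output variable name, in the spirit of GetHash.
  const std::string key = "2-10-out_var@softmax_p";
  if (cache.GetBlob(key) == nullptr) {
    cache.SetBlob(key, std::make_shared<int>(42));  // first call: create and store the "primitive"
    std::cout << "created\n";
  } else {
    std::cout << "reused\n";  // later calls with the same key reuse the cached object
  }
  return 0;
}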
@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
     // Generate keys for storing/retrieving primitives for this operator
-    // TODO(jczaja): Each MKLDNN operator may have diffrent hashing function
-    auto gethash = [](memory::dims& operand_dims) {
-      return std::string(std::to_string(operand_dims[0]) + "-" +
-                         std::to_string(operand_dims[1]));
-    };
-    const std::string key = gethash(softmax_tz);
-    const std::string key_softmax_p = key + "@softmax_p";
-    const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
-    const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";
-
-    std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p);
-    if (softmax_p == nullptr) {
-      // Currently only NC data format is supported
-      auto softmax_md =
-          MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
-      // Normalization is made after innermost dimension eg. C out of NC
-      auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
-                                                softmax_md, 1 /*dim: C*/);
-      // create memory primitives
-      auto softmax_src_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{softmax_md, mkldnn_engine},
-          static_cast<void*>(const_cast<T*>(input_data)));
-      dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
-      auto softmax_dst_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{softmax_md, mkldnn_engine},
-          static_cast<void*>(output_data));
-      dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);
-
-      auto softmax_forward_pd =
-          std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
-                                                            mkldnn_engine);
-      softmax_p = std::make_shared<softmax_forward>(
-          *(softmax_forward_pd.get()),
-          *(static_cast<memory*>(softmax_src_memory_p.get())),
-          *(static_cast<memory*>(softmax_dst_memory_p.get())));
-      dev_ctx.SetBlob(key_softmax_p, softmax_p);
-    } else {
-      // Primitives already exist
-      auto src_memory_p = std::static_pointer_cast<memory>(
-          dev_ctx.GetBlob(key_softmax_src_mem_p));
-      PADDLE_ENFORCE(src_memory_p != nullptr,
-                     "Fail to find softmax src mem_p in device context");
-      auto dst_memory_p = std::static_pointer_cast<memory>(
-          dev_ctx.GetBlob(key_softmax_dst_mem_p));
-      PADDLE_ENFORCE(dst_memory_p != nullptr,
-                     "Fail to find softmax dst mem_p in device context");
-      src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(const_cast<T*>(input_data)));
-      dst_memory_p->set_data_handle(output_data);
-    }
+    const std::string key =
+        platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
+    const std::string key_softmax_pd = key + "@softmax_pd";
+
+    // Currently only NC data format is supported
+    auto softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    // Normalization is made after innermost dimension eg. C out of NC
+    auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
+                                              softmax_md, 1 /*dim: C*/);
+    auto softmax_pd = std::make_shared<mkldnn::softmax_forward::primitive_desc>(
+        softmax_desc, mkldnn_engine);
+    dev_ctx.SetBlob(key_softmax_pd, softmax_pd);
+
+    SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key);
+    auto softmax_src_memory_p =
+        handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
+    auto softmax_dst_memory_p =
+        handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
+    auto softmax_p =
+        handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
 
     std::vector<primitive> pipeline{
         *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
@@ -120,10 +164,83 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+    const Tensor* output = ctx.Input<Tensor>("Out");
+    const T* dst_data = output->data<T>();
+
+    auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
+    const auto* diff_dst_ptr = dout->template data<T>();
+
+    auto* dx =
+        ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
+    T* diff_src_ptr = dx->template mutable_data<T>(ctx.GetPlace());
+
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    std::vector<int> src_tz(dst_tz);
+    PADDLE_ENFORCE(output->dims().size() == 2UL,
+                   "The input of softmax op must be a 2D matrix.");
+    // MKL-DNN supports softmax over a selected axis; for a 2D tensor we
+    // normalize over the last axis, i.e. axis 1
+    PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
+                   "Softmax input and output dimensions should match");
+    // Same memory descriptor to be used for input and output
+    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
+    // Currently only the NC data format is supported
+    // Retrieve the forward softmax primitive descriptor from the device context
+    const std::string key =
+        platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
+    const std::string key_softmax_pd = key + "@softmax_pd";
+
+    auto softmax_pd =
+        std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
+            dev_ctx.GetBlob(key_softmax_pd));
+    PADDLE_ENFORCE(softmax_pd != nullptr,
+                   "Fail to find softmax_pd in device context");
+
+    // TODO(jczaja): Add layouts support when there is a need to do so
+    // Two dimensional softmax does support NC format
+    auto data_softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    auto diff_softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    // Normalization is made after innermost dimension eg. C out of NC
+    auto softmax_bwd_desc =
+        softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
+    auto softmax_bwd_pd =
+        std::make_shared<mkldnn::softmax_backward::primitive_desc>(
+            softmax_bwd_desc, mkldnn_engine, *softmax_pd);
+
+    SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
+                                 mkldnn_engine, key);
+    auto dst_memory_p =
+        handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
+    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
+        diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
+    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
+        diff_softmax_md, to_void_cast<T>(diff_src_ptr));
+
+    // Get primitive from device context
+    auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
+        dst_memory_p, diff_dst_memory_p, diff_src_memory_p);
+
+    std::vector<primitive> pipeline{*softmax_bwd_p};
+    stream(stream::kind::eager).submit(pipeline).wait();
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::SoftmaxMKLDNNKernel<float>);
+REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::SoftmaxMKLDNNGradKernel<float>);
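For reference, the quantity the mkldnn::softmax_backward primitive produces over the channel axis of an NC tensor is dx = y * (dy - sum_c(dy * y)), computed per row from the forward output y and the upstream gradient dy. A plain C++ reference sketch of that computation (names and layout here are illustrative and not part of the patch):

#include <cstddef>
#include <vector>

// Reference softmax gradient for an N x C row-major tensor:
// dx[n][c] = y[n][c] * (dy[n][c] - sum_k dy[n][k] * y[n][k])
void SoftmaxGradRef(const std::vector<float>& y,   // forward softmax output
                    const std::vector<float>& dy,  // gradient w.r.t. output
                    std::vector<float>* dx,        // gradient w.r.t. input
                    std::size_t N, std::size_t C) {
  dx->assign(N * C, 0.f);
  for (std::size_t n = 0; n < N; ++n) {
    float dot = 0.f;
    for (std::size_t c = 0; c < C; ++c) dot += dy[n * C + c] * y[n * C + c];
    for (std::size_t c = 0; c < C; ++c)
      (*dx)[n * C + c] = y[n * C + c] * (dy[n * C + c] - dot);
  }
}

int main() {
  std::vector<float> y = {0.7f, 0.3f};   // softmax output for one row (N=1, C=2)
  std::vector<float> dy = {1.f, 0.f};    // upstream gradient
  std::vector<float> dx;
  SoftmaxGradRef(y, dy, &dx, 1, 2);      // dx = {0.7*(1-0.7), 0.3*(0-0.7)} = {0.21, -0.21}
  return 0;
}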

paddle/fluid/operators/softmax_op.cc

Lines changed: 18 additions & 4 deletions
@@ -145,16 +145,30 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
     framework::LibraryType library_{framework::LibraryType::kPlain};
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        framework::StringToDataLayout(data_format), library_);
+#ifdef PADDLE_WITH_MKLDNN
+    if (library_ == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kMKLDNN;
+      layout_ = framework::DataLayout::kMKLDNN;
+    }
+#endif
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    if (input_data_type == framework::proto::VarType::FP16) {
+      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                     "float16 can only be used on GPU place");
+    }
+
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 };
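The net effect of this hunk is that SoftmaxOpGrad's kernel selection now falls back from cuDNN to MKLDNN to the plain CPU kernel, and rejects float16 off the GPU. A small standalone C++ sketch of that selection order (the enum and helper below are illustrative, not Paddle's real API):

#include <iostream>

enum class Library { kPlain, kCUDNN, kMKLDNN };

// Pick the backend in the same priority order as the changed GetExpectedKernelType.
Library ChooseLibrary(bool can_use_cudnn, bool can_use_mkldnn) {
  Library lib = Library::kPlain;
  if (can_use_cudnn) lib = Library::kCUDNN;  // CUDA build with cuDNN usable wins first
  if (lib == Library::kPlain && can_use_mkldnn) lib = Library::kMKLDNN;  // else fall back to MKLDNN
  return lib;
}

int main() {
  // No cuDNN, MKLDNN available: prints 2, i.e. kMKLDNN.
  std::cout << static_cast<int>(ChooseLibrary(false, true)) << "\n";
  return 0;
}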
