Skip to content

Commit a98a3fd

Browse files
authored
Merge pull request #9385 from mozga-intel/mozga/mkldnn-fc
Implementation of MKLDNN FC
2 parents d02a68c + 46e14bb commit a98a3fd

File tree

5 files changed

+608
-24
lines changed

5 files changed

+608
-24
lines changed
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/tensor.h"
16+
#include "paddle/fluid/operators/fc_op.h"
17+
#include "paddle/fluid/platform/device_context.h"
18+
#include "paddle/fluid/platform/mkldnn_helper.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using paddle::framework::Tensor;
24+
using paddle::platform::MKLDNNDeviceContext;
25+
26+
template <typename T>
27+
class MKLDNNMD {
28+
public:
29+
explicit MKLDNNMD(const T* in, const T* w, bool bias)
30+
: in{paddle::framework::vectorize2int(in->dims())},
31+
w{paddle::framework::vectorize2int(w->dims())} {
32+
with_bias_ = bias;
33+
}
34+
35+
mkldnn::memory::desc dst() const {
36+
return platform::MKLDNNMemDesc({in[0], w[1]},
37+
mkldnn::memory::data_type::f32,
38+
mkldnn::memory::format::nc);
39+
}
40+
41+
mkldnn::memory::desc src() const {
42+
return is_spatial()
43+
? platform::MKLDNNMemDesc({in[0], in[1], in[2], in[3]},
44+
mkldnn::memory::data_type::f32,
45+
mkldnn::memory::format::nchw)
46+
: platform::MKLDNNMemDesc({in[0], in[1]},
47+
mkldnn::memory::data_type::f32,
48+
mkldnn::memory::format::nc);
49+
}
50+
51+
mkldnn::memory::desc weights() const {
52+
return is_spatial()
53+
? platform::MKLDNNMemDesc({w[1], in[1], in[2], in[3]},
54+
mkldnn::memory::data_type::f32,
55+
mkldnn::memory::format::oihw)
56+
: platform::MKLDNNMemDesc({w[1], in[1]},
57+
mkldnn::memory::data_type::f32,
58+
mkldnn::memory::format::oi);
59+
}
60+
61+
mkldnn::memory::desc bias() const {
62+
return with_bias_
63+
? platform::MKLDNNMemDesc({w[1]}, mkldnn::memory::data_type::f32,
64+
mkldnn::memory::format::format_undef)
65+
: platform::MKLDNNMemDesc({}, mkldnn::memory::data_type::f32,
66+
mkldnn::memory::format::format_undef);
67+
}
68+
69+
private:
70+
bool is_spatial() const { return in.size() > 1 && w.size() > 1; }
71+
72+
std::vector<int> in;
73+
std::vector<int> w;
74+
bool with_bias_;
75+
bool is_spatial_;
76+
};
77+
78+
class MKLDNNMemory {
79+
public:
80+
MKLDNNMemory(MKLDNNMD<Tensor>* t, const mkldnn::engine& e)
81+
: md_{t}, engine_{e} {}
82+
virtual ~MKLDNNMemory() = default;
83+
84+
template <typename Output>
85+
mkldnn::memory dst(const Output* out) {
86+
return mkldnn::memory({md_->dst(), engine_},
87+
static_cast<void*>(const_cast<float*>(out)));
88+
}
89+
90+
template <typename Output>
91+
mkldnn::memory dst(Output* out) {
92+
return mkldnn::memory({md_->dst(), engine_}, out);
93+
}
94+
95+
template <typename Input>
96+
mkldnn::memory src(const Input* in) {
97+
return mkldnn::memory({md_->src(), engine_},
98+
static_cast<void*>(const_cast<float*>(in)));
99+
}
100+
101+
template <typename Weight>
102+
mkldnn::memory weights(const Weight* w) {
103+
return mkldnn::memory({md_->weights(), engine_},
104+
static_cast<void*>(const_cast<float*>(w)));
105+
}
106+
107+
mkldnn::memory bias() {
108+
return mkldnn::memory(mkldnn::memory::primitive_desc(md_->bias(), engine_));
109+
}
110+
111+
private:
112+
MKLDNNMD<Tensor>* md_;
113+
const mkldnn::engine& engine_;
114+
};
115+
116+
template <typename T>
117+
class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
118+
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
119+
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
120+
"It must use CPUPlace.");
121+
122+
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
123+
const auto& mkldnn_engine = dev_ctx.GetEngine();
124+
125+
auto input = ctx.Input<Tensor>("Input");
126+
auto w = ctx.Input<Tensor>("W");
127+
128+
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
129+
"Input must be with 2 or 4 dimensions, i.e. NCHW");
130+
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
131+
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
132+
133+
bool with_bias = ctx.Attr<bool>("bias_attr");
134+
MKLDNNMD<Tensor> md(input, w, with_bias);
135+
136+
std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> pd =
137+
FcFwdPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
138+
with_bias, mkldnn_engine);
139+
140+
const std::string key = ctx.op().Output("Out");
141+
const std::string key_fc_pd = key + "@fc_pd";
142+
143+
dev_ctx.SetBlob(key_fc_pd, pd);
144+
145+
MKLDNNMemory mem(&md, mkldnn_engine);
146+
147+
const T* input_data = input->data<T>();
148+
const T* w_data = w->data<T>();
149+
150+
auto output = ctx.Output<Tensor>("Out");
151+
T* output_data = output->mutable_data<T>(ctx.GetPlace());
152+
153+
auto dst_memory = mem.dst(output_data);
154+
auto src_memory = mem.src(input_data);
155+
auto weights_memory = mem.weights(w_data);
156+
auto bias_memory = mem.bias();
157+
158+
auto forward = with_bias ? mkldnn::inner_product_forward(
159+
*pd, src_memory, weights_memory, bias_memory,
160+
dst_memory)
161+
: mkldnn::inner_product_forward(
162+
*pd, src_memory, weights_memory, dst_memory);
163+
164+
std::vector<mkldnn::primitive> pipeline = {forward};
165+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
166+
}
167+
168+
private:
169+
std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>
170+
FcFwdPrimitiveDesc(const mkldnn::memory::desc& src,
171+
const mkldnn::memory::desc& weights,
172+
const mkldnn::memory::desc& dst,
173+
const mkldnn::memory::desc& bias, const bool with_bias,
174+
const mkldnn::engine& engine) const {
175+
auto desc = with_bias
176+
? mkldnn::inner_product_forward::desc(
177+
mkldnn::prop_kind::forward, src, weights, bias, dst)
178+
: mkldnn::inner_product_forward::desc(
179+
mkldnn::prop_kind::forward, src, weights, dst);
180+
181+
auto pd = new mkldnn::inner_product_forward::primitive_desc(desc, engine);
182+
return std::unique_ptr<mkldnn::inner_product_forward::primitive_desc>(pd);
183+
}
184+
};
185+
186+
template <typename T>
187+
class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
188+
public:
189+
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
190+
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
191+
"It must use CPUPlace.");
192+
193+
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
194+
const auto& mkldnn_engine = dev_ctx.GetEngine();
195+
196+
T* input_grad_data = nullptr;
197+
T* w_grad_data = nullptr;
198+
199+
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
200+
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
201+
202+
if (input_grad) {
203+
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
204+
}
205+
if (w_grad) {
206+
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
207+
}
208+
209+
const Tensor* input = ctx.Input<Tensor>("Input");
210+
const T* input_data = input->data<T>();
211+
212+
const Tensor* w = ctx.Input<Tensor>("W");
213+
const T* w_data = w->data<T>();
214+
215+
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
216+
const T* out_grad_data = out_grad->data<T>();
217+
218+
bool with_bias = ctx.Attr<bool>("bias_attr");
219+
220+
MKLDNNMD<Tensor> md(input, w, with_bias);
221+
MKLDNNMemory mem(&md, mkldnn_engine);
222+
223+
auto dst_memory = mem.dst(out_grad_data);
224+
auto src_memory = mem.src(input_data);
225+
auto weights_memory = mem.weights(w_data);
226+
auto bias_memory = mem.bias();
227+
228+
const std::string key = ctx.op().Input("Out");
229+
const std::string key_fc_pd = key + "@fc_pd";
230+
231+
auto pd =
232+
std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
233+
dev_ctx.GetBlob(key_fc_pd));
234+
235+
PADDLE_ENFORCE(pd != nullptr, "Fail to find key_fc_pd in device context");
236+
237+
if (w_grad) {
238+
auto weights_grad_memory = mem.weights(w_grad_data);
239+
240+
mkldnn::inner_product_backward_weights::primitive_desc bwd_weight_pd =
241+
FcBwdWeightsPrimitiveDesc(md.src(), md.weights(), md.dst(), md.bias(),
242+
with_bias, *pd, mkldnn_engine);
243+
244+
auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
245+
bwd_weight_pd, src_memory, dst_memory, weights_grad_memory,
246+
bias_memory);
247+
248+
std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
249+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
250+
}
251+
252+
if (input_grad) {
253+
auto src_grad_memory = mem.src(input_grad_data);
254+
255+
mkldnn::inner_product_backward_data::primitive_desc bwd_data_pd =
256+
FcBwdDataPrimitiveDesc(md.src(), md.weights(), md.dst(), *pd,
257+
mkldnn_engine);
258+
259+
auto bwd_data_prim = mkldnn::inner_product_backward_data(
260+
bwd_data_pd, dst_memory, weights_memory, src_grad_memory);
261+
262+
std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
263+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
264+
}
265+
}
266+
267+
private:
268+
mkldnn::inner_product_backward_weights::primitive_desc
269+
FcBwdWeightsPrimitiveDesc(
270+
const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
271+
const mkldnn::memory::desc& diff_dst, const mkldnn::memory::desc& bias,
272+
const bool with_bias,
273+
const mkldnn::inner_product_forward::primitive_desc& pd,
274+
const mkldnn::engine& engine) const {
275+
auto bwd_weight_desc = with_bias
276+
? mkldnn::inner_product_backward_weights::desc(
277+
src, diff_weights, bias, diff_dst)
278+
: mkldnn::inner_product_backward_weights::desc(
279+
src, diff_weights, bias, diff_dst);
280+
281+
return mkldnn::inner_product_backward_weights::primitive_desc(
282+
bwd_weight_desc, engine, pd);
283+
}
284+
285+
mkldnn::inner_product_backward_data::primitive_desc FcBwdDataPrimitiveDesc(
286+
const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
287+
const mkldnn::memory::desc& diff_dst,
288+
const mkldnn::inner_product_forward::primitive_desc& pd,
289+
const mkldnn::engine& engine) const {
290+
auto bwd_data_desc =
291+
mkldnn::inner_product_backward_data::desc(diff_src, weights, diff_dst);
292+
return mkldnn::inner_product_backward_data::primitive_desc(bwd_data_desc,
293+
engine, pd);
294+
}
295+
};
296+
} // namespace operators
297+
} // namespace paddle
298+
299+
// Register the MKL-DNN float kernels for the fc operator (forward and grad)
// on CPUPlace.
REGISTER_OP_KERNEL(fc, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::FCMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL(fc_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::FCMKLDNNGradOpKernel<float>);

0 commit comments

Comments
 (0)