Skip to content

Commit cd32dda

Browse files
Sand3r- authored and luotao1 committed
Fuse Convolution and Eltwise Add into MKLDNN's Conv+Bias (#12669)
* Fuse Convolution and Eltwise Add into Conv+Bias * Reduce bias branching at conv_mkldnn_op * Add MKLDNN build checks for Conv Bias * Conv-bias: check if bias input exists before assignment * Conv-bias: Remove Bias dim check from infershape It was causing conv3d test to crash upon calling HasInput(Bias)
1 parent 896a37b commit cd32dda

File tree

4 files changed

+200
-18
lines changed

4 files changed

+200
-18
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
126126
pipeline);
127127
}
128128

129+
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
130+
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
131+
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
132+
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
133+
auto bias_pd = conv_pd_->bias_primitive_desc();
134+
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
135+
"@bias_mem_p", pipeline);
136+
}
137+
129138
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
130139
std::shared_ptr<mkldnn::memory> src_memory_p,
131140
std::shared_ptr<mkldnn::memory> weights_memory_p,
@@ -147,6 +156,28 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
147156
return conv_p;
148157
}
149158

159+
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
160+
std::shared_ptr<mkldnn::memory> src_memory_p,
161+
std::shared_ptr<mkldnn::memory> weights_memory_p,
162+
std::shared_ptr<mkldnn::memory> bias_memory_p,
163+
std::shared_ptr<mkldnn::memory> dst_memory_p) {
164+
auto prim_key = key_ + "@conv_p";
165+
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
166+
dev_ctx_.GetBlob(prim_key));
167+
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
168+
"Fail to find convolution primitive in device context");
169+
if (conv_p == nullptr) {
170+
conv_p = std::make_shared<mkldnn::convolution_forward>(
171+
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
172+
*(bias_memory_p.get()), *(dst_memory_p.get()));
173+
174+
dev_ctx_.SetBlob(prim_key, conv_p);
175+
} else {
176+
is_reusing_ = true;
177+
}
178+
return conv_p;
179+
}
180+
150181
std::shared_ptr<mkldnn::convolution_backward_weights>
151182
AcquireConvolutionBackwardWeights(
152183
std::shared_ptr<mkldnn::memory> src_memory_p,
@@ -229,6 +260,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
229260

230261
auto* input = ctx.Input<Tensor>("Input");
231262
auto* filter = ctx.Input<Tensor>("Filter");
263+
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
232264
auto* output = ctx.Output<Tensor>("Output");
233265

234266
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
@@ -237,6 +269,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
237269
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
238270
filter->format() != memory::format::format_undef,
239271
"Wrong layout/format set for Filter tensor");
272+
PADDLE_ENFORCE(input->dims().size() == 4,
273+
"Input must be with 4 dimensions, i.e. NCHW");
274+
PADDLE_ENFORCE(filter->dims().size() == 4,
275+
"Filter must be with 4 dimensions, i.e. OIHW");
276+
if (bias) {
277+
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
278+
bias->format() != memory::format::format_undef,
279+
"Wrong layout/format set for Bias tensor");
280+
PADDLE_ENFORCE(bias->dims().size() == 1,
281+
"Bias must only have 1 dimension, i.e. X");
282+
}
240283

241284
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
242285
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
@@ -253,11 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
253296
const T* filter_data = filter->data<T>();
254297
T* output_data = output->mutable_data<T>(ctx.GetPlace());
255298

256-
PADDLE_ENFORCE(input->dims().size() == 4,
257-
"Input must be with 4 dimensions, i.e. NCHW");
258-
PADDLE_ENFORCE(filter->dims().size() == 4,
259-
"Filter must be with 4 dimensions, i.e. OIHW");
260-
261299
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
262300
std::vector<int> weights_tz =
263301
paddle::framework::vectorize2int(filter->dims());
@@ -288,13 +326,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
288326
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
289327
auto weights_md = platform::MKLDNNMemDesc(
290328
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
329+
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
330+
// Currently used whenever bias is != nullptr.
291331
auto dst_md = platform::MKLDNNMemDesc(
292332
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
293333

294334
// create a conv primitive descriptor and save it for usage in backward
295-
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
296-
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
297-
mkldnn_engine);
335+
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
336+
if (bias) {
337+
bias_tz = paddle::framework::vectorize2int(bias->dims());
338+
auto bias_md = platform::MKLDNNMemDesc(
339+
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
340+
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
341+
strides, paddings, mkldnn_engine);
342+
} else {
343+
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
344+
paddings, mkldnn_engine);
345+
}
298346
// Save conv_pd/src_memory/weights_memory for backward pass
299347
dev_ctx.SetBlob(key_conv_pd, conv_pd);
300348

@@ -315,8 +363,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
315363
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
316364

317365
// create convolution op primitive
318-
auto conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
319-
dst_memory_p);
366+
std::shared_ptr<mkldnn::convolution_forward> conv_p;
367+
if (bias) {
368+
const T* bias_data = bias->data<T>();
369+
auto user_bias_md = platform::MKLDNNMemDesc(
370+
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
371+
auto user_bias_memory_p =
372+
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
373+
374+
auto bias_memory_p =
375+
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
376+
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
377+
bias_memory_p, dst_memory_p);
378+
} else {
379+
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
380+
dst_memory_p);
381+
}
320382

321383
// push primitive to stream and wait until it's executed
322384
pipeline.push_back(*conv_p);
@@ -346,6 +408,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
346408
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
347409
p_conv_pd);
348410
}
411+
412+
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
413+
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
414+
const memory::desc& bias, const memory::desc& dst,
415+
const std::vector<int>& strides,
416+
const std::vector<int>& paddings,
417+
const mkldnn::engine& engine) const {
418+
memory::dims stride_dims = {strides[0], strides[1]};
419+
memory::dims padding_dims = {paddings[0], paddings[1]};
420+
421+
auto conv_desc = mkldnn::convolution_forward::desc(
422+
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
423+
bias, dst, stride_dims, padding_dims, padding_dims,
424+
mkldnn::padding_kind::zero);
425+
426+
auto p_conv_pd =
427+
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
428+
429+
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
430+
p_conv_pd);
431+
}
349432
};
350433

351434
template <typename T>

paddle/fluid/operators/conv_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
3737

3838
auto in_dims = ctx->GetInputDim("Input");
3939
auto filter_dims = ctx->GetInputDim("Filter");
40+
4041
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
4142
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
4243
int groups = ctx->Attrs().Get<int>("groups");
@@ -57,7 +58,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
5758
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
5859
"The number of input channels should be equal to filter "
5960
"channels * groups.");
60-
6161
PADDLE_ENFORCE_EQ(
6262
filter_dims[0] % groups, 0,
6363
"The number of output channels should be divided by groups.");
@@ -122,6 +122,11 @@ void Conv2DOpMaker::Make() {
122122
"H is the height of the filter, and W is the width of the filter. "
123123
"If the groups attribute is greater than 1, C equals the number of "
124124
"input image channels divided by the groups.");
125+
AddInput("Bias",
126+
"(Tensor) Bias to be added to each output of filter application."
127+
"The format of output tensor is X (one-dimensional) of size equal"
128+
"to the number of output channels. Only used with MKL-DNN.")
129+
.AsDispensable();
125130
AddOutput("Output",
126131
"(Tensor) The output tensor of convolution operator. "
127132
"The format of output tensor is also NCHW.")

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ class MKLDNNHandler {
125125
return this->AcquireMemory(md, ptr, "@user_weights_mem_p");
126126
}
127127

128+
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
129+
const mkldnn::memory::desc& md, void* ptr) {
130+
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
131+
}
132+
128133
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
129134
const mkldnn::memory::desc& md, void* ptr) {
130135
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");

python/paddle/fluid/transpiler/inference_transpiler.py

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ def transpile(self, program, place, scope=None):
5959
scope = global_scope()
6060
if not isinstance(scope, core.Scope):
6161
raise TypeError("scope should be as Scope type or None")
62-
self._fuse_batch_norm(program, place, scope)
63-
self._fuse_relu_mkldnn(program)
62+
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
63+
if use_mkldnn:
64+
self._fuse_relu_mkldnn(program)
65+
self._fuse_conv_bias_mkldnn(program)
66+
else:
67+
self._fuse_batch_norm(program, place, scope)
6468

6569
def _fuse_relu_mkldnn(self, program):
6670
'''
@@ -82,10 +86,6 @@ def _fuse_relu_mkldnn(self, program):
8286
:param program: program to transpile
8387
:type program: Program
8488
'''
85-
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
86-
if not use_mkldnn:
87-
return
88-
8989
self.block = program.block(0)
9090

9191
i = 0
@@ -106,6 +106,69 @@ def _fuse_relu_mkldnn(self, program):
106106
# And a better solution will be considered later.
107107
program = program.clone()
108108

109+
def _fuse_conv_bias_mkldnn(self, program):
    '''
    Transpile the program by fusing convolution and elementwise_add.

    Replace conv2d and elementwise_add ops with a new conv2d op
    based on an old conv2d op and the :math:`Bias` taken from
    elementwise_add.

    For input :math:`X`:

    - Conv process: :math:`X = input * W`
    - Elementwise_add process: :math:`X = X + bias`

    After fuse into one operation:

    .. math::

        X = input * W + bias

    The operator transformation is:

    - before:

      - conv->elementwise_add->any_other_op

    - after:

      - conv->any_other_op

    The transpile stages are:

    1. Extract bias and output variables from elementwise_add.
    2. Extract Input, Weight and attributes from conv op.
    3. Create a new convolution op based on extracted params.
    4. Remove old conv op.
    5. Remove elementwise_add.
    6. Remove unused variables.

    Args:
        program (Program): program to transpile

    '''
    self.block = program.block(0)

    i = 0
    # Each fusion inspects ops[i] and ops[i + 1], hence the - 2 bound.
    while i < len(self.block.ops) - 2:
        current_op = self.block.ops[i]
        next_op = self.block.ops[i + 1]
        # conv2d immediately followed by elementwise_add (the bias add)
        if current_op.type in ['conv2d'] and \
            next_op.type in ['elementwise_add']:
            self._fuse_conv_bias(i, current_op, next_op)
            # _fuse_conv_bias inserted the fused conv at index i, shifting
            # the old conv to i + 1 and elementwise_add to i + 2.
            self.block._remove_op(i + 1)  # Remove old conv
            self.block._remove_op(i + 1)  # Remove elementwise_add
            i = i + 1
        i = i + 1

    self._remove_unused_var()
    # TODO(luotao): use clone() method to flush the program.desc in force,
    # since some large program.desc will not be flushed immediately.
    # And a better solution will be considered later.
    program = program.clone()
171+
109172
def _fuse_batch_norm(self, program, place, scope):
110173
'''
111174
Transpile the program by fused batch normalization.
@@ -185,7 +248,6 @@ def _fuse_batch_norm(self, program, place, scope):
185248
self.block._remove_op(i + 2)
186249
i = i + 1
187250
i = i + 1
188-
189251
self._adjust_input()
190252
self._remove_unused_var()
191253
# TODO(luotao): use clone() method to flush the program.desc in force,
@@ -288,6 +350,33 @@ def _load_param(param_name):
288350
# collect the renamed input
289351
self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
290352

353+
def _fuse_conv_bias(self, index, conv_op, elementwise_add_op):
    '''
    fuse the conv op with elementwise_add

    :param index: index of the conv_op in ops list
    :type index: Int
    :param conv_op: convolution operator
    :type conv_op: Operator
    :param elementwise_add_op: convolution's bias operator
    :type elementwise_add_op: Operator
    '''

    # Gather the variables the fused op wires together: the conv's own
    # input and filter, plus the bias and output of the elementwise_add.
    in_var = self.block.var(conv_op.input("Input")[0])
    filter_var = self.block.var(conv_op.input("Filter")[0])
    bias_var = self.block.var(elementwise_add_op.input("Y")[0])
    out_var = self.block.var(elementwise_add_op.output("Out")[0])
    # Carry over every attribute of the original conv unchanged.
    attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}

    fused_inputs = {
        "Input": in_var,
        "Filter": filter_var,
        "Bias": bias_var,
    }
    self.block._insert_op(
        index,
        type="conv2d",
        inputs=fused_inputs,
        outputs={"Output": out_var},
        attrs=attrs)
379+
291380
def _adjust_input(self):
292381
for i in range(len(self.block.ops)):
293382
current_op = self.block.ops[i]

0 commit comments

Comments
 (0)