Skip to content

Commit 1052a79

Browse files
committed
support group convolution layer with mkldnn.
1 parent ae7fb2a commit 1052a79

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
302302
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
303303
int groups = ctx.Attr<int>("groups");
304304

305-
// TODO(pzelazko-intel) add support for group convolution and dilation
306-
PADDLE_ENFORCE(groups == 1, "group convolution is not implemented yet");
305+
// TODO: add support for dilation
307306
PADDLE_ENFORCE(
308307
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
309308
"dilation in convolution is not implemented yet");
@@ -314,6 +313,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
314313
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
315314
std::vector<int> weights_tz =
316315
paddle::framework::vectorize2int(filter->dims());
316+
int g = std::max(groups, 1);
317+
if (g > 1) {
318+
int o = weights_tz[0];
319+
int i = weights_tz[1];
320+
int h = weights_tz[2];
321+
int w = weights_tz[3];
322+
weights_tz.resize(5);
323+
weights_tz[0] = g;
324+
weights_tz[1] = o / g;
325+
weights_tz[2] = i;
326+
weights_tz[3] = h;
327+
weights_tz[4] = w;
328+
}
317329
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
318330

319331
// Get unique name for storing MKLDNN primitives
@@ -327,7 +339,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
327339
auto user_src_md = platform::MKLDNNMemDesc(
328340
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
329341
auto user_weights_md = platform::MKLDNNMemDesc(
330-
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
342+
{weights_tz}, platform::MKLDNNGetDataType<T>(),
343+
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
331344

332345
/* create memory descriptor for convolution without specified format
333346
* ('any') which lets a primitive (convolution in this case) choose
@@ -340,7 +353,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
340353
auto src_md = platform::MKLDNNMemDesc(
341354
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
342355
auto weights_md = platform::MKLDNNMemDesc(
343-
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
356+
weights_tz, platform::MKLDNNGetDataType<T>(),
357+
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
344358
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
345359
// Currently used whenever bias is != nullptr.
346360
auto dst_md = platform::MKLDNNMemDesc(

0 commit comments

Comments
 (0)