@@ -302,8 +302,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
302
302
bool fuse_relu = ctx.Attr <bool >(" fuse_relu" );
303
303
int groups = ctx.Attr <int >(" groups" );
304
304
305
- // TODO(pzelazko-intel) add support for group convolution and dilation
306
- PADDLE_ENFORCE (groups == 1 , " group convolution is not implemented yet" );
305
+ // TODO: add support for dilation
307
306
PADDLE_ENFORCE (
308
307
dilations.size () == 2 && dilations[0 ] == 1 && dilations[1 ] == 1 ,
309
308
" dilation in convolution is not implemented yet" );
@@ -314,6 +313,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
314
313
std::vector<int > src_tz = paddle::framework::vectorize2int (input->dims ());
315
314
std::vector<int > weights_tz =
316
315
paddle::framework::vectorize2int (filter->dims ());
316
+ int g = std::max (groups, 1 );
317
+ if (g > 1 ) {
318
+ int o = weights_tz[0 ];
319
+ int i = weights_tz[1 ];
320
+ int h = weights_tz[2 ];
321
+ int w = weights_tz[3 ];
322
+ weights_tz.resize (5 );
323
+ weights_tz[0 ] = g;
324
+ weights_tz[1 ] = o / g;
325
+ weights_tz[2 ] = i;
326
+ weights_tz[3 ] = h;
327
+ weights_tz[4 ] = w;
328
+ }
317
329
std::vector<int > dst_tz = paddle::framework::vectorize2int (output->dims ());
318
330
319
331
// Get unique name for storing MKLDNN primitives
@@ -327,7 +339,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
327
339
auto user_src_md = platform::MKLDNNMemDesc (
328
340
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format ());
329
341
auto user_weights_md = platform::MKLDNNMemDesc (
330
- {weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format ());
342
+ {weights_tz}, platform::MKLDNNGetDataType<T>(),
343
+ (g == 1 ) ? filter->format () : mkldnn::memory::format::goihw);
331
344
332
345
/* create memory descriptor for convolution without specified format
333
346
* ('any') which lets a primitive (convolution in this case) choose
@@ -340,7 +353,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
340
353
auto src_md = platform::MKLDNNMemDesc (
341
354
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
342
355
auto weights_md = platform::MKLDNNMemDesc (
343
- weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
356
+ weights_tz, platform::MKLDNNGetDataType<T>(),
357
+ (g == 1 ) ? chosen_memory_format : mkldnn::memory::format::goihw);
344
358
std::vector<int > bias_tz; // TODO(mgallus): avoid empty vector creation.
345
359
// Currently used whenever bias is != nullptr.
346
360
auto dst_md = platform::MKLDNNMemDesc (
0 commit comments