@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
280280 * ('any') which lets a primitive (convolution in this case) choose
281281 * the memory format preferred for best performance
282282 */
283+ std::string data_format = ctx.Attr <std::string>(" data_format" );
284+ auto chosen_memory_format =
285+ platform::data_format_to_memory_format (data_format);
286+
283287 auto src_md = platform::MKLDNNMemDesc (
284- src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
288+ src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
285289 auto weights_md = platform::MKLDNNMemDesc (
286- weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
290+ weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
287291 auto dst_md = platform::MKLDNNMemDesc (
288- dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
292+ dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
289293
290294 // create a conv primitive descriptor and save it for usage in backward
291295 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
423427 * ('any') which lets a primitive (conv backward in this case) choose
424428 * the memory format preferred for best performance
425429 */
430+ std::string data_format = ctx.Attr <std::string>(" data_format" );
431+ auto chosen_memory_format =
432+ platform::data_format_to_memory_format (data_format);
433+
426434 auto src_md = platform::MKLDNNMemDesc (
427- src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
435+ src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
428436 auto diff_src_md = platform::MKLDNNMemDesc (
429- src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
437+ src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
430438 auto weights_md = platform::MKLDNNMemDesc (
431- weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
439+ weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
432440 auto diff_weights_md = platform::MKLDNNMemDesc (
433- weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
441+ weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
434442 auto diff_dst_md = platform::MKLDNNMemDesc (
435- dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any );
443+ dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format );
436444
437445 // Retrieve conv_pd from device context
438446 auto conv_pd =
0 commit comments