@@ -99,20 +99,16 @@ static std::tuple<memory::desc, memory::desc, memory::desc> qconv_get_plain_md(
9999 memory::desc wgh_usr_md,
100100 memory::desc dst_usr_md,
101101 memory::dims wgh_tz,
102- bool is_channels_last_suggested) {
102+ const int64_t ndim,
103+ int64_t groups,
104+ bool is_wgh_channels_last) {
103105 // create memory desc for conv primitive and query the blocked format
104106 memory::desc src_md, wgh_md, dst_md;
105107 src_md = src_usr_md;
106108 dst_md = dst_usr_md;
107- if (is_channels_last_suggested) {
108- // TODO: remove this path when oneDNN fix the accuracy issue.
109- // in ChannelsLast senario, fmt_wgh should be nhwc instead of any
110- auto fmt_any = memory::format_tag::any;
111- auto wei_data_t = memory::data_type::s8;
112- wgh_md = memory::desc (wgh_tz, wei_data_t , fmt_any);
113- } else {
114- wgh_md = wgh_usr_md;
115- }
109+ auto fmt_wgh = conv_wgh_fmt (ndim, groups != 1 , is_wgh_channels_last);
110+ auto wei_data_t = memory::data_type::s8;
111+ wgh_md = memory::desc (wgh_tz, wei_data_t , fmt_wgh);
116112 return {src_md, wgh_md, dst_md};
117113}
118114
@@ -326,21 +322,19 @@ static at::Tensor quantized_convolution(
326322 conv_fwd_pd = convolution_forward::primitive_desc (
327323 const_cast <dnnl_primitive_desc_t >(conv_fwd_pd_t ));
328324 } else {
329- if (is_onednn_layout_suggested) {
330- std::tie (src_md, wgh_md, dst_md) =
331- qconv_get_blocked_md (src, src_usr_md, wgh_usr_md, dst_usr_md);
332- } else {
333- auto ic = src.size (1 );
334- auto oc = dst.size (1 );
335- memory::dims wgh_tz =
336- compatible_wgh_dims (ndim, groups, oc, ic, wgh.sizes ());
337- std::tie (src_md, wgh_md, dst_md) = qconv_get_plain_md (
338- src_usr_md,
339- wgh_usr_md,
340- dst_usr_md,
341- wgh_tz,
342- is_channels_last_suggested);
343- }
325+ auto ic = src.size (1 );
326+ auto oc = dst.size (1 );
327+ memory::dims wgh_tz =
328+ compatible_wgh_dims (ndim, groups, oc, ic, wgh.sizes ());
329+ std::tie (src_md, wgh_md, dst_md) = qconv_get_plain_md (
330+ src_usr_md,
331+ wgh_usr_md,
332+ dst_usr_md,
333+ wgh_tz,
334+ ndim,
335+ groups,
336+ wgh.is_contiguous (at::MemoryFormat::ChannelsLast) ||
337+ wgh.is_contiguous (at::MemoryFormat::ChannelsLast3d));
344338
345339 pattr.set_scales_mask (DNNL_ARG_SRC, mask_ac);
346340 pattr.set_scales_mask (DNNL_ARG_WEIGHTS, mask_wgh);
@@ -406,18 +400,6 @@ static at::Tensor quantized_convolution(
406400 src_m = dpcpp_onednn_memory (src_usr_md, engine, src.data_ptr ());
407401 dst_m = dpcpp_onednn_memory (dst_usr_md, engine, dst.data_ptr ());
408402 wgh_m = dpcpp_onednn_memory (wgh_usr_md, engine, wgh.data_ptr ());
409- if (memory_layout_for_conv == MEMORY_LAYOUT_FOR_CONV::ChannelsLast) {
410- // TODO: Should remove after oneDNN fix the accuracy issue
411- auto expected_wgh_md = conv_fwd_pd.weights_desc ();
412- wgh_m = qconv_get_expected_wgh_memory (
413- wgh,
414- wgh_blocked,
415- wgh_usr_md,
416- expected_wgh_md,
417- wgh_scales,
418- engine,
419- weight_cache_optimization);
420- }
421403 }
422404
423405 std::unordered_map<int , memory> args;
0 commit comments