Skip to content

Commit 35a7ec5

Browse files
author
Zhu Yuhua
authored
change qconv to channels-last (#4414) (#5006)
* change conv to channels-last (cl) on ATSM
* force qconv to channels-last format
* clean code
* fix format issue
* add channels_last_3d and fix unit tests
* check if weight (wgh) is contiguous format (cf)

Signed-off-by: zhuyuhua-v <[email protected]>
1 parent d043124 commit 35a7ec5

File tree

2 files changed

+25
-39
lines changed

2 files changed

+25
-39
lines changed

csrc/gpu/oneDNN/QConv.h

Lines changed: 19 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -99,20 +99,16 @@ static std::tuple<memory::desc, memory::desc, memory::desc> qconv_get_plain_md(
9999
memory::desc wgh_usr_md,
100100
memory::desc dst_usr_md,
101101
memory::dims wgh_tz,
102-
bool is_channels_last_suggested) {
102+
const int64_t ndim,
103+
int64_t groups,
104+
bool is_wgh_channels_last) {
103105
// create memory desc for conv primitive and query the blocked format
104106
memory::desc src_md, wgh_md, dst_md;
105107
src_md = src_usr_md;
106108
dst_md = dst_usr_md;
107-
if (is_channels_last_suggested) {
108-
// TODO: remove this path when oneDNN fix the accuracy issue.
109-
// in ChannelsLast senario, fmt_wgh should be nhwc instead of any
110-
auto fmt_any = memory::format_tag::any;
111-
auto wei_data_t = memory::data_type::s8;
112-
wgh_md = memory::desc(wgh_tz, wei_data_t, fmt_any);
113-
} else {
114-
wgh_md = wgh_usr_md;
115-
}
109+
auto fmt_wgh = conv_wgh_fmt(ndim, groups != 1, is_wgh_channels_last);
110+
auto wei_data_t = memory::data_type::s8;
111+
wgh_md = memory::desc(wgh_tz, wei_data_t, fmt_wgh);
116112
return {src_md, wgh_md, dst_md};
117113
}
118114

@@ -326,21 +322,19 @@ static at::Tensor quantized_convolution(
326322
conv_fwd_pd = convolution_forward::primitive_desc(
327323
const_cast<dnnl_primitive_desc_t>(conv_fwd_pd_t));
328324
} else {
329-
if (is_onednn_layout_suggested) {
330-
std::tie(src_md, wgh_md, dst_md) =
331-
qconv_get_blocked_md(src, src_usr_md, wgh_usr_md, dst_usr_md);
332-
} else {
333-
auto ic = src.size(1);
334-
auto oc = dst.size(1);
335-
memory::dims wgh_tz =
336-
compatible_wgh_dims(ndim, groups, oc, ic, wgh.sizes());
337-
std::tie(src_md, wgh_md, dst_md) = qconv_get_plain_md(
338-
src_usr_md,
339-
wgh_usr_md,
340-
dst_usr_md,
341-
wgh_tz,
342-
is_channels_last_suggested);
343-
}
325+
auto ic = src.size(1);
326+
auto oc = dst.size(1);
327+
memory::dims wgh_tz =
328+
compatible_wgh_dims(ndim, groups, oc, ic, wgh.sizes());
329+
std::tie(src_md, wgh_md, dst_md) = qconv_get_plain_md(
330+
src_usr_md,
331+
wgh_usr_md,
332+
dst_usr_md,
333+
wgh_tz,
334+
ndim,
335+
groups,
336+
wgh.is_contiguous(at::MemoryFormat::ChannelsLast) ||
337+
wgh.is_contiguous(at::MemoryFormat::ChannelsLast3d));
344338

345339
pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
346340
pattr.set_scales_mask(DNNL_ARG_WEIGHTS, mask_wgh);
@@ -406,18 +400,6 @@ static at::Tensor quantized_convolution(
406400
src_m = dpcpp_onednn_memory(src_usr_md, engine, src.data_ptr());
407401
dst_m = dpcpp_onednn_memory(dst_usr_md, engine, dst.data_ptr());
408402
wgh_m = dpcpp_onednn_memory(wgh_usr_md, engine, wgh.data_ptr());
409-
if (memory_layout_for_conv == MEMORY_LAYOUT_FOR_CONV::ChannelsLast) {
410-
// TODO: Should remove after oneDNN fix the accuracy issue
411-
auto expected_wgh_md = conv_fwd_pd.weights_desc();
412-
wgh_m = qconv_get_expected_wgh_memory(
413-
wgh,
414-
wgh_blocked,
415-
wgh_usr_md,
416-
expected_wgh_md,
417-
wgh_scales,
418-
engine,
419-
weight_cache_optimization);
420-
}
421403
}
422404

423405
std::unordered_map<int, memory> args;

tests/gpu/examples/test_qconv_channels_last.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,9 @@ def test_qconv_simple_channels_last(self, dtype=torch.float):
4747

4848
inputs = inputs.to(memory_format=torch.channels_last)
4949
inputs_gpu = inputs.to("xpu")
50-
filters_gpu = filters.to("xpu")
50+
filters_gpu = filters.to("xpu").to(
51+
memory_format=torch.channels_last
52+
)
5153
bias_gpu = bias.to("xpu")
5254

5355
q_inputs_gpu = torch.quantize_per_tensor(
@@ -117,7 +119,9 @@ def test_qconv_simple_channels_last_3d(self, dtype=torch.float):
117119

118120
inputs = inputs.to(memory_format=torch.channels_last_3d)
119121
inputs_gpu = inputs.to("xpu")
120-
filters_gpu = filters.to("xpu")
122+
filters_gpu = filters.to("xpu").to(
123+
memory_format=torch.channels_last_3d
124+
)
121125
bias_gpu = bias.to("xpu")
122126
q_inputs_gpu = torch.quantize_per_tensor(
123127
inputs_gpu, scale_in, zero_point_in, dtype_inputs

0 commit comments

Comments (0)