
Commit 5bdcb24

[Quantized DeConv Support] Enable Quantized Transposed Convs with groups==1
Supporting quantized transposed convs with groups == 1.

Previously, there was some added support for quantized transposed convolutions, but only when the channel axis is 1 and groups is 1. The current quantizer didn't support this because it only allows quantizing along dim 0, which is generally the output channels. For transposed convs, however, the weights have the dimensions:

```
[in_channels, out_channels/groups, h, w]
```

Since we want to keep quantization along the output channels, we now need to quantize along axis = 1.

The reason we require groups to be 1 is that XNNPACK takes in filters of the dimension:

```
[out_channels, H, W, in_channels/groups]
```

Since we are quantizing along the output channels, in PyTorch we expect to have out_channels/groups scales, but in XNNPACK we have out_channels scales. Realistically, we would need to support this with some affine quantization, where we provide a scale for every group and every out_channel. For now, we simply enforce the constraint that groups == 1.

Differential Revision: [D76631781](https://our.internmc.facebook.com/intern/diff/D76631781/)

[ghstack-poisoned]
1 parent f7cc72f commit 5bdcb24
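To make the layout mismatch in the commit message concrete, here is a minimal sketch (not part of the commit; channel counts and scale values are arbitrary) that quantizes a transposed-conv weight per output channel with `torch.quantize_per_channel`:

```python
import torch

# Transposed-conv weights are laid out [in_channels, out_channels/groups, kH, kW],
# so per-output-channel quantization must use axis=1 rather than the usual axis=0.
in_channels, out_channels, groups = 6, 10, 2
weight = torch.randn(in_channels, out_channels // groups, 4, 4)

# PyTorch produces one scale per entry along axis 1: out_channels/groups scales.
num_scales = weight.shape[1]  # out_channels // groups == 5
scales = torch.full((num_scales,), 0.02)
zero_points = torch.zeros(num_scales, dtype=torch.int64)

q_weight = torch.quantize_per_channel(
    weight, scales, zero_points, axis=1, dtype=torch.qint8
)
print(q_weight.q_per_channel_scales().numel())  # 5

# XNNPACK reorders filters to [out_channels, kH, kW, in_channels/groups] and
# expects out_channels (here 10) scales, so the 5 scales above have no direct
# mapping unless groups == 1, which is the constraint this commit enforces.
```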

File tree: 2 files changed (+58 −100 lines)

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 26 additions & 2 deletions
```diff
@@ -238,7 +238,19 @@ def _do_annotate_conv(

         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = QuantizationSpec(
+                dtype=weight_qspec.dtype,
+                quant_min=weight_qspec.quant_min,
+                quant_max=weight_qspec.quant_max,
+                qscheme=weight_qspec.qscheme,
+                ch_axis=1,
+                is_dynamic=False,
+                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+            )
+        input_qspec_map[weight] = weight_qspec

         # Only annotate dynamically quantized conv if it's 2D and not depthwise
         if (
```
```diff
@@ -311,7 +323,19 @@ def _do_annotate_conv_relu(

         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = QuantizationSpec(
+                dtype=weight_qspec.dtype,
+                quant_min=weight_qspec.quant_min,
+                quant_max=weight_qspec.quant_max,
+                qscheme=weight_qspec.qscheme,
+                ch_axis=1,
+                is_dynamic=False,
+                observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+            )
+        input_qspec_map[weight] = weight_qspec

         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
```
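For orientation, here is a hedged sketch of how this annotation path would be exercised end to end, assuming the standard pt2e flow and this repo's quantizer module path; the `TConv` toy model is invented for the example:

```python
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # groups == 1, so the ch_axis=1 weight annotation above applies
        self.tconv = torch.nn.ConvTranspose2d(3, 5, kernel_size=4, groups=1)

    def forward(self, x):
        return self.tconv(x)


model = TConv().eval()
inputs = (torch.randn(1, 3, 8, 8),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

exported = torch.export.export_for_training(model, inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*inputs)  # calibrate with sample inputs
quantized = convert_pt2e(prepared)
```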

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 32 additions & 98 deletions
```diff
@@ -221,7 +221,6 @@ def _test(
         conv_count=1,
         dtype: torch.dtype = torch.float,
         check_quantized=True,
-        delegated=True,
     ):
         # pyre-fixme[29]: `Union[torch._tensor.Tensor,
         # torch.nn.modules.module.Module]` is not a function.
```
```diff
@@ -240,29 +239,20 @@ def _test(

         (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

-        if delegated:
-            (
-                tester.check_not(
-                    ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
-                )
-                .check_not(
-                    [
-                        "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
-                    ]
-                )
-                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-                .to_executorch()
-                .serialize()
-                .run_method_and_compare_outputs(qtol=1)
+        (
+            tester.check_not(
+                ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
             )
-        else:
-            # need quantize ops when ops are not delegated to xnnpack
-            if has_quantized_ops:
-                (
-                    tester.to_executorch()
-                    .serialize()
-                    .run_method_and_compare_outputs(qtol=1)
-                )
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(qtol=1)
+        )

     def _test_dq(
         self,
```
```diff
@@ -325,7 +315,6 @@ def test_qs8_conv2d_per_channel(self) -> None:
         self._test(
             Conv2d(transpose=transpose),
             quant_config=get_symmetric_quantization_config(is_per_channel=True),
-            delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
         )

     def test_fp32_conv2d_seq(self) -> None:
```
```diff
@@ -485,7 +474,6 @@ def get_inputs(self):
         self._test(
             ConvReLU(transpose=transpose),
             quant_config=get_symmetric_quantization_config(is_per_channel=True),
-            delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
         )

     def test_qs8_conv2d_dw_relu(self):
```
```diff
@@ -537,8 +525,6 @@ def get_inputs(self):
             quant_config=get_symmetric_quantization_config(
                 is_per_channel=per_channel_quant
             ),
-            # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
-            delegated=not (transpose and per_channel_quant),
         )

     def test_qs8_conv2d_relu_seq(self):
```
```diff
@@ -593,7 +579,7 @@ def get_inputs(self):
             conv_count=2,
         )

-    def test_qs8_conv_transpose_2d_quantize_per_channel(self):
+    def test_qs8_conv_transpose_2d_quantize_per_channel_multi_axis(self):
         class PerChannelConvTranspose2d(torch.nn.Module):
             def __init__(self, input_channels, output_channels, groups, axis):
                 super().__init__()
```
```diff
@@ -662,76 +648,24 @@ def get_inputs(self):
                )

        for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                    delegated=axis == 1
-                    and groups
-                    == 1,  # xnnpack only support output channel axis quantization with groups == 1
-                )
-
-    def test_qs8_conv_transpose_2d_dqd_f32_weights(self):
-        class TransposeConv2dDQDf32weights(torch.nn.Module):
-            def __init__(self, input_channels, output_channels, groups, axis):
-                super().__init__()
-                self.input_channels = input_channels
-                self.output_channels = output_channels
-                self.axis = axis
-                self.groups = groups
-                self.transpose = True
-                self.weights = torch.nn.Parameter(
-                    torch.randn((input_channels, output_channels // groups, 4, 4)),
-                    requires_grad=False,
-                )
-
-                axis_size = self.weights.shape[axis]
-                self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345)
-                self.zero_point = torch.nn.Parameter(
-                    torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False
-                )
-
-            def forward(self, x):
-                dequantize_input = (
-                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                        x, 0.12345, 0, -127, 127, torch.int8
+            for ch_axis in (1, 2):
+                if ch_axis == 1 and groups == 1:
+                    self._test(
+                        PerChannelConvTranspose2d(
+                            3 * groups, 5 * groups, groups, ch_axis
+                        ),  # ch_axis=0
+                        quant_config=None,
+                        conv_count=1,
                    )
-                )
-                x = torch.nn.functional.conv_transpose2d(
-                    dequantize_input, self.weights, groups=self.groups
-                )
-
-                return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
-                        x,
-                        0.12345,
-                        0,
-                        -127,
-                        127,
-                        torch.int8,
-                    ),
-                    0.12345,
-                    0,
-                    -127,
-                    127,
-                    torch.int8,
-                )
-
-            def get_inputs(self):
-                return (
-                    torch.randint(
-                        low=-127, high=127, size=(3, self.input_channels, 4, 4)
-                    ).type(dtype=torch.int8),
-                )
-
-        for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                )
+                else:
+                    with self.assertRaises(RuntimeError):
+                        self._test(
+                            PerChannelConvTranspose2d(
+                                3 * groups, 5 * groups, groups, ch_axis
+                            ),  # ch_axis=0
+                            quant_config=None,
+                            conv_count=1,
+                        )

     def test_padded_output_tconv(self):
         class TConv2d(torch.nn.Module):
```
```diff
@@ -761,7 +695,7 @@ def forward(self, x):

         (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

-        # tconv should not be offloaded to XNNPack, since output padding is not
+        # tconv should not be offloaded to XNNPack, since output padding is not supported
         (
             tester.check(
                 ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
```