
Commit ae2cd61

digantdesai authored and facebook-github-bot committed
Add support for quantized transposed conv
Summary: As title. Updates the tests accordingly. Differential Revision: D68939306
1 parent: 3413971

File tree: 3 files changed, +32 −17 lines

backends/xnnpack/quantizer/xnnpack_quantizer.py
Lines changed: 2 additions & 1 deletion

@@ -249,8 +249,9 @@ class XNNPACKQuantizer(Quantizer):
     STATIC_OPS = [
         "linear_relu",
         "linear",
-        "conv_relu",
         "conv",
+        "conv_transpose",
+        "conv_relu",
         "conv_transpose_relu",
         "adaptive_avg_pool2d",
         # TODO: move this to BoltNNQuantizer?
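With "conv_transpose" now listed in STATIC_OPS, a standalone transposed convolution is annotated by default. A minimal usage sketch (not part of this commit; the import paths and the PT2E capture API are assumptions that vary across PyTorch/ExecuTorch versions):

    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class TransposedConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3)

        def forward(self, x):
            return self.conv(x)

    model = TransposedConv().eval()
    example_inputs = (torch.randn(1, 4, 16, 16),)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    # After this commit the "conv_transpose" annotator runs with the other
    # STATIC_OPS, so the transposed conv is quantized like a regular conv.
    captured = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(captured, quantizer)
    prepared(*example_inputs)  # calibration
    quantized = convert_pt2e(prepared)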

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
Lines changed: 26 additions & 6 deletions

@@ -285,18 +285,17 @@ def _annotate_linear_relu(
     return annotated_partitions
 
 
-@register_annotator("conv")
-def _annotate_conv(
+def _do_annotate_conv(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
+    is_conv_transpose: bool = False,
 ) -> Optional[list[list[Node]]]:
     annotated_partitions = []
+    is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+
     for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if not is_conv_node(n):
            continue
         conv_node = n
 
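The predicates _is_conv_node and _is_conv_transpose_node are referenced but not defined in this diff. A plausible sketch of what they check — the regular-conv targets come from the deleted lines above, while the transposed-conv overload list is an assumption:

    import torch
    from torch.fx import Node

    def _is_conv_node(n: Node) -> bool:
        # Regular convolutions as captured by torch.export.
        return n.op == "call_function" and n.target in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
        ]

    def _is_conv_transpose_node(n: Node) -> bool:
        # Transposed convolutions; the exact overload names are assumptions.
        return n.op == "call_function" and n.target in [
            torch.ops.aten.conv_transpose1d.default,
            torch.ops.aten.conv_transpose2d.input,
        ]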

@@ -392,6 +391,27 @@ def _do_annotate_conv_relu(
     annotated_partitions.append(partition)
     return annotated_partitions
 
+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose")
+def _annotate_transpose_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(
+        gm, quantization_config, filter_fn, is_conv_transpose=True
+    )
+
 
 @register_annotator("conv_relu")
 def _annotate_conv_relu(
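The @register_annotator decorator used by these wrappers is defined elsewhere in this file; conceptually it maps the string names listed in STATIC_OPS to annotator functions. A rough sketch of the pattern it implies (OP_TO_ANNOTATOR and the exact signature are assumptions):

    from typing import Callable, Dict

    # Hypothetical registry; the real file keeps an equivalent mapping that
    # XNNPACKQuantizer consults for each name listed in STATIC_OPS.
    OP_TO_ANNOTATOR: Dict[str, Callable] = {}

    def register_annotator(op: str):
        def decorator(fn: Callable):
            OP_TO_ANNOTATOR[op] = fn
            return fn

        return decorator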

backends/xnnpack/test/ops/test_conv2d.py
Lines changed: 4 additions & 10 deletions

@@ -243,15 +243,14 @@ def test_qs8_conv2d_test(self) -> None:
             self._test(
                 Conv2d(bias=has_bias, transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet
             )
 
     def test_qs8_conv2d_per_channel(self) -> None:
         for transpose in (True, False):
             self._test(
                 Conv2d(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet
+                delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
             )
 
     def test_fp32_conv2d_seq(self) -> None:
@@ -264,7 +263,6 @@ def test_qs8_conv2d_seq(self) -> None:
                 Conv2dSeq(transpose=transpose),
                 conv_count=2,
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet
             )
 
     def test_fp32_conv2d_single_int_params(self):
@@ -282,7 +280,6 @@ def test_fp32_conv2d_depthwise(self):
         # - Groups must equal In Channels
         # - Out Channels must be a positive multiple of In Channels
         for transpose in (True, False):
-
             self._test(
                 Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose)
             )
@@ -292,7 +289,6 @@ def test_qs8_conv2d_depthwise(self):
             self._test(
                 Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
-                check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet
             )
 
     def test_fp32_conv2d_bn(self):
@@ -384,7 +380,6 @@ def test_qs8_conv2d_bn(self):
                 Conv2dBatchNorm(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(),
                 conv_count=2,
-                check_quantized=not transpose, # XNNPackQuantizer does not quantize this pattern yet
             )
 
     def test_qs8_conv2d_relu(self):
@@ -415,7 +410,7 @@ def get_inputs(self):
             self._test(
                 ConvReLU(transpose=transpose),
                 quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                delegated=not transpose,
+                delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
             )
 
     def test_qs8_conv2d_dw_relu(self):
@@ -467,9 +462,8 @@ def get_inputs(self):
                 quant_config=get_symmetric_quantization_config(
                     is_per_channel=per_channel_quant
                 ),
-                # xnnpack only supports per output channel quantization for transposed convolutions
-                # XNNPackQuantizer quantizes per input channel currently
-                delegated=not transpose or not per_channel_quant,
+                # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
+                delegated=not (transpose and per_channel_quant),
             )
 
     def test_qs8_conv2d_relu_seq(self):
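The comment repeated in these tests follows from PyTorch's weight layouts: Conv2d stores weights as (out_channels, in_channels // groups, kH, kW), while ConvTranspose2d stores (in_channels, out_channels // groups, kH, kW). Quantizing along dim 0 — the usual per-channel axis — therefore scales per input channel for transposed convolutions, which XNNPACK rejects when groups > 1. A quick illustration (not part of the diff):

    import torch

    conv = torch.nn.Conv2d(4, 8, kernel_size=3)
    deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=3)

    print(conv.weight.shape)    # torch.Size([8, 4, 3, 3]) -> dim 0 = out_channels
    print(deconv.weight.shape)  # torch.Size([4, 8, 3, 3]) -> dim 0 = in_channels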
