2 changes: 1 addition & 1 deletion backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
QuantPattern("linear", True, False, LINEAR_TARGETS),
QuantPattern("conv", True, False, CONV_TARGETS),
QuantPattern("conv_transpose", False, False, CONV_TARGETS),
QuantPattern("conv_transpose", True, False, CONV_TARGETS),
QuantPattern("conv_relu", False, False, CONV_TARGETS),
QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
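For context, this one-line flip marks the `conv_transpose` pattern's second field as `True`, bringing it in line with `linear` and `conv`; judging by the annotator changes below, that field appears to gate dynamic quantization for the pattern. A minimal sketch of how this surfaces to users, assuming the standard PT2E flow (API locations vary across ExecuTorch/PyTorch versions; older flows used `capture_pre_autograd_graph`):

```python
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class Upsampler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.deconv = torch.nn.ConvTranspose2d(4, 8, kernel_size=2, stride=2)

    def forward(self, x):
        return self.deconv(x)


example = (torch.randn(1, 4, 8, 8),)
module = torch.export.export_for_training(Upsampler(), example).module()

quantizer = XNNPACKQuantizer()
# With this PR, a dynamic per-channel config can now match the
# conv_transpose pattern instead of leaving it unannotated.
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)
prepared = prepare_pt2e(module, quantizer)
prepared(*example)  # calibration pass
converted = convert_pt2e(prepared)
```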
72 changes: 59 additions & 13 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -4,7 +4,10 @@

 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator


+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
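The new helper clones a `QuantizationSpec`, overriding only the fields that are provided. One subtlety worth noting: because the fallbacks use `or`, a falsy override (e.g. `ch_axis=0` or `is_dynamic=False`) is indistinguishable from "not provided". That is harmless for the `ch_axis=1` call sites below, but a future caller passing a falsy value would silently inherit the original. A small illustrative sketch (the observer and quant range here are arbitrary assumptions, not values from this PR):

```python
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Per-channel int8 weight spec with the conv-style ch_axis=0.
base = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(),
)

# Same spec, but with scales laid out along axis 1, where a
# transposed conv keeps its output channels.
deconv_spec = change_quantization_config(base, ch_axis=1)
assert deconv_spec.ch_axis == 1 and deconv_spec.dtype == base.dtype
```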
@@ -231,31 +256,44 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue

+        # Tracks conditions for whether or not to skip
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
         input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)

-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        # skip if transposed conv has more than 1 group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+        print(f"{skip} conv transpose and num_groups")
Reviewer (Contributor): remove

+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
Reviewer (Contributor): Do we support group > 1 here in the quant flow? I know we are checking above, but just curious. If yes, we can move this before the group check; if not, add an assert to avoid future issues when we allow groups > 1 for transposed_conv.
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
             and quantization_config.input_activation.is_dynamic
-        ):
+        )

+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)

             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )

         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
@@ -265,7 +303,7 @@
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)

-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue

         if filter_fn and any(not filter_fn(n) for n in partition):
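The `ch_axis=1` choice above follows from PyTorch's weight layouts: `Conv2d` stores weights as `(out_channels, in_channels // groups, kH, kW)`, while `ConvTranspose2d` stores `(in_channels, out_channels // groups, kH, kW)`, so per-output-channel scales for a transposed conv must run along axis 1. A quick shape check:

```python
import torch

# Conv2d: output channels on axis 0.
w = torch.nn.Conv2d(4, 8, kernel_size=3).weight
print(w.shape)  # torch.Size([8, 4, 3, 3])

# ConvTranspose2d: output channels on axis 1.
wt = torch.nn.ConvTranspose2d(4, 8, kernel_size=3).weight
print(wt.shape)  # torch.Size([4, 8, 3, 3])
```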
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu(

         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+        input_qspec_map[weight] = weight_qspec

         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
@@ -323,6 +366,9 @@
         if _is_annotated(partition):
             continue

+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
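The `groups != 1` guard here mirrors the `skip` logic in `_do_annotate_conv`: with a grouped transposed conv, axis 1 of the weight holds only `out_channels // groups` entries, so a single per-channel axis no longer maps one-to-one onto output channels. A quick demonstration:

```python
import torch

# groups=1: axis 1 matches all 8 output channels, safe to annotate.
w1 = torch.nn.ConvTranspose2d(4, 8, kernel_size=3).weight
print(w1.shape)  # torch.Size([4, 8, 3, 3])

# groups=2: axis 1 holds out_channels // groups = 4 entries for 8
# output channels, so the annotator skips this pattern.
w2 = torch.nn.ConvTranspose2d(4, 8, kernel_size=3, groups=2).weight
print(w2.shape)  # torch.Size([4, 4, 3, 3])
```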