151 changes: 71 additions & 80 deletions backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -3,7 +3,8 @@

import copy
import functools
from typing import Any, Callable, Optional, TYPE_CHECKING
from dataclasses import dataclass
from typing import Any, Callable, Optional, Set, TYPE_CHECKING

import torch
import torch._dynamo as torchdynamo
@@ -235,37 +236,52 @@ def not_module_type_or_name_filter(n: Node) -> bool:
return not_module_type_or_name_filter


class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"conv_bn_relu",
"conv_bn",
"conv_transpose_bn_relu",
"conv_transpose_bn",
]
@dataclass
class QuantPattern:
name: str
is_dynamic: bool
is_qat: bool
op_overloads: Set[torch._ops.OpOverloadPacket]


CONV_TARGETS = {
torch.ops.aten.conv2d.default,
torch.ops.aten.conv1d.default,
torch.ops.aten.convolution.default,
}

LINEAR_TARGETS = {
torch.ops.aten.linear.default,
}

ADAPTIVE_AVG_POOL2D_TARGETS = {torch.ops.aten.adaptive_avg_pool2d.default}

ADD_TARGETS = {torch.ops.aten.add.Tensor}

MUL_TARGETS = {torch.ops.aten.mul.Tensor}

CAT_TARGETS = {torch.ops.aten.cat.default}

# static quantization ops (both PTQ and QAT)
# Preserve the order that fusions come before singular ops
STATIC_OPS = [
"linear_relu",
"linear",
"conv",
"conv_transpose",
"conv_relu",
"conv_transpose_relu",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",
"add_relu",
"add",
"mul_relu",
"mul",
"cat",
]

DYNAMIC_OPS = [
"linear",
"conv",
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
SUPPORTED_PATTERNS = [
QuantPattern("conv_bn_relu", False, True, CONV_TARGETS),
QuantPattern("conv_bn", False, True, CONV_TARGETS),
QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS),
QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS),
QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
QuantPattern("linear", True, False, LINEAR_TARGETS),
QuantPattern("conv", True, False, CONV_TARGETS),
QuantPattern("conv_transpose", False, False, CONV_TARGETS),
QuantPattern("conv_relu", False, False, CONV_TARGETS),
QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),
QuantPattern("add_relu", False, False, ADD_TARGETS),
QuantPattern("add", False, False, ADD_TARGETS),
QuantPattern("mul_relu", False, False, MUL_TARGETS),
QuantPattern("mul", False, False, MUL_TARGETS),
QuantPattern("cat", False, False, CAT_TARGETS),
]

def __init__(self) -> None:
@@ -347,83 +363,58 @@ def transform_for_annotation(

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# hacked for handling dynamic linear quant. will fix later.
if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
model = self._annotate_for_dynamic_quantization_config(model)
else:
model = self._annotate_for_static_quantization_config(model)
model = self._annotate_for_quantization_config(model)
propagate_annotation(model)
return model

def _annotate_all_static_patterns(
def _annotate_all_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
operator_target: Optional[torch._ops.OpOverloadPacket] = None,
):
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
return model

if quantization_config.is_qat:
for op in self.STATIC_QAT_ONLY_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
for op in self.STATIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
for pattern in self.SUPPORTED_PATTERNS:
if operator_target and operator_target not in pattern.op_overloads:
# if operator_target is specified, skip patterns that aren't
# associated with that target
continue
if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
elif quantization_config.is_qat and pattern.is_qat:
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
elif not quantization_config.input_activation.is_dynamic:
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)

def _annotate_all_dynamic_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement the support for None to be canceling out previous annotations
if quantization_config is None:
return model

for op in self.DYNAMIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model

def _annotate_for_static_quantization_config(
def _annotate_for_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_static_patterns(
self._annotate_all_patterns(
model, config, _get_module_name_filter(module_name)
)

tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_static_patterns(
self._annotate_all_patterns(
model, config, _get_module_type_filter(module_type)
)

self._annotate_all_static_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model

def _annotate_for_dynamic_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_name_filter(module_name)
)

tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_type_filter(module_type)
for op, config in self.operator_type_config.items():
self._annotate_all_patterns(
model,
config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
op,
)

self._annotate_all_dynamic_patterns(
self._annotate_all_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
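
The core of this change replaces the STATIC_QAT_ONLY_OPS / STATIC_OPS / DYNAMIC_OPS string lists with a single QuantPattern table, so one annotation pass (_annotate_all_patterns) can be scoped to a specific operator overload via set_operator_type. A minimal usage sketch follows; the import path is assumed from the file location, and putting both operator-type configs on a single quantizer instance is an illustration only (the updated tests below compose separate quantizer instances instead).

    import torch
    # Assumed import path, inferred from backends/xnnpack/quantizer/xnnpack_quantizer.py
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = XNNPACKQuantizer()
    # Dynamic quantization scoped to linear ops (LINEAR_TARGETS in the pattern table).
    quantizer.set_operator_type(
        torch.ops.aten.linear.default,
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True),
    )
    # Static per-channel quantization scoped to conv2d (one of CONV_TARGETS).
    quantizer.set_operator_type(
        torch.ops.aten.conv2d.default,
        get_symmetric_quantization_config(is_per_channel=True),
    )
    # During annotate(), each per-operator config is passed to _annotate_all_patterns,
    # which skips any QuantPattern whose op_overloads set does not contain the target.
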
16 changes: 12 additions & 4 deletions backends/xnnpack/test/quantizer/test_pt2e_quantization.py
@@ -172,10 +172,14 @@ def test_composable_quantizer_linear_conv(self) -> None:
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
dynamic_quantizer.set_global(quantization_config_dynamic)
dynamic_quantizer.set_operator_type(
torch.ops.aten.linear.default, quantization_config_dynamic
)
static_quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(quantization_config)
static_quantizer.set_operator_type(
torch.ops.aten.conv2d.default, quantization_config
)
# Note that dynamic quantization must be applied first here.
# this is because static quantizer also quantizes linear with static qspec
# and if we apply static_quantizer first then dynamic_quantizer cannot be applied
@@ -271,10 +275,14 @@ def test_embedding_conv_linear_quantization(self) -> None:
quantization_config_dynamic = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
dynamic_quantizer.set_global(quantization_config_dynamic)
dynamic_quantizer.set_operator_type(
torch.ops.aten.linear.default, quantization_config_dynamic
)
static_quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
static_quantizer.set_global(quantization_config)
static_quantizer.set_operator_type(
torch.ops.aten.conv2d.default, quantization_config
)
composed_quantizer = ComposableQuantizer(
[embedding_quantizer, dynamic_quantizer, static_quantizer]
)
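
For reference, the updated test wiring composes a dynamic-linear quantizer with a static-conv quantizer, keeping the dynamic one first as the test comment requires. A condensed sketch, assuming ComposableQuantizer lives at torch.ao.quantization.quantizer.composable_quantizer and the same XNNPACKQuantizer import path as above:

    import torch
    from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Dynamic quantization, restricted to linear.
    dynamic_quantizer = XNNPACKQuantizer()
    dynamic_quantizer.set_operator_type(
        torch.ops.aten.linear.default,
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True),
    )

    # Static per-channel quantization, restricted to conv2d.
    static_quantizer = XNNPACKQuantizer()
    static_quantizer.set_operator_type(
        torch.ops.aten.conv2d.default,
        get_symmetric_quantization_config(is_per_channel=True),
    )

    # Dynamic first: the static quantizer would otherwise claim linear with a static qspec.
    composed_quantizer = ComposableQuantizer([dynamic_quantizer, static_quantizer])
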
@@ -665,7 +665,7 @@ def test_dynamic_linear_with_conv(self):
quantization_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
quantizer.set_global(quantization_config)
quantizer.set_operator_type(torch.ops.aten.linear.default, quantization_config)
m_eager = TestHelperModules.ConvLinearWPermute().eval()

node_occurrence = {
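
The one-line change in test_dynamic_linear_with_conv captures the migration for existing callers: a dynamic config that used to be installed globally is now scoped to the linear overload. A minimal before/after sketch, under the same assumed import path as above:

    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    quantizer = XNNPACKQuantizer()
    dynamic_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)

    # Before: dynamic linear quantization rode on the global config.
    # quantizer.set_global(dynamic_config)

    # After: the dynamic config is attached to the linear op overload explicitly.
    quantizer.set_operator_type(torch.ops.aten.linear.default, dynamic_config)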