4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -24,6 +24,9 @@
)
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack._passes.replace_u8_convert_with_dq_pass import (
ReplaceU8ConvertWithDqPass,
)
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
TagImplicitQDqPass,
)
@@ -70,6 +73,7 @@ def __init__(
PReLUReshapePass,
ChannelsLastTaggedReshapePass,
TagImplicitQDqPass,
ReplaceU8ConvertWithDqPass,
]
else:
self.passes = passes
88 changes: 88 additions & 0 deletions backends/xnnpack/_passes/replace_u8_convert_with_dq_pass.py
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


class ReplaceU8ConvertWithDqPass(XNNPACKPass):
"""
    The XNNPACK delegate supports U8 tensors by treating them as asymmetrically
    quantized U8 tensors with a zero point of 0 and a scale of 1. Convert ops
    from U8 to F32 are handled by replacing the conversion with a dequantize
    operation; this pass performs that replacement.
"""

@staticmethod
def can_replace_to_copy_node(node: torch.fx.Node):
"""
Returns true if the _to_copy node can be replaced with a dequantize
operation. This is possible if the input dtype is u8, the output dtype
is f32, and the dim order is not changed.
"""
if node.op != "call_function" or node.target not in [
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
]:
return False

if node.kwargs.get("dtype", None) != torch.float:
return False

input_node = node.args[0]
if (
not isinstance(input_node, torch.fx.Node)
or "val" not in input_node.meta
or input_node.meta["val"].dtype != torch.uint8
):
return False

if node.target == exir_ops.edge.aten._to_copy.default:
            # TODO: Don't assume channels_first?
if node.kwargs.get("memory_format", torch.preserve_format) not in [
torch.preserve_format,
torch.contiguous_format,
]:
return False
elif node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
default_dim_order = list(range(len(node.meta["val"].shape)))
if node.kwargs.get("dim_order", default_dim_order) != default_dim_order:
return False

return True

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
node_list = list(graph.nodes)
for node in node_list:
if (
node.op == "call_function"
and node.target == exir_ops.edge.aten._to_copy.default
):
if not ReplaceU8ConvertWithDqPass.can_replace_to_copy_node(node):
continue

with graph.inserting_before(node):
dq_node = graph.call_function(
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
(
node.args[0], # Tensor
1.0, # Scale
0, # Zero point
0, # Qmin
255, # Qmax
torch.uint8, # Dtype
),
)

node.replace_all_uses_with(dq_node)
graph.erase_node(node)

graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
36 changes: 30 additions & 6 deletions backends/xnnpack/operators/node_visitor.py
@@ -203,7 +203,7 @@ def gen_ids_and_flags(

return ext_id, id_out, flag

def get_serialized_dtype(
def get_serialized_dtype( # noqa: 14
self,
quant_params: Optional[QuantParams],
node: torch.fx.Node,
@@ -254,11 +254,12 @@ def get_per_channel_dtype(
if quant_params.per_channel:
dtype = get_per_channel_dtype(quant_params)
else:
dtype = (
XNNDatatype.xnn_datatype_qint32
if quant_params.dtype == torch.int32
else XNNDatatype.xnn_datatype_qint8
)
if quant_params.dtype == torch.int32:
dtype = XNNDatatype.xnn_datatype_qint32
elif quant_params.dtype == torch.uint8:
dtype = XNNDatatype.xnn_datatype_quint8
else:
dtype = XNNDatatype.xnn_datatype_qint8
else:
node_dtype = get_node_dtype(node)
if node_dtype is not None and node_dtype == torch.float16:
@@ -337,6 +338,20 @@ def _check_per_channel_group_params(
# For now group quantization is only supported for 4b weights
assert quant_params.is_qc4w, "Only 4b group quantization is supported"

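    # Implicit quantization parameters used to treat a plain u8 tensor as an
    # asymmetric quantized tensor (scale=1.0, zero_point=0, range [0, 255]).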
def _create_qparams_for_u8(self, tensor: torch.fx.Node) -> QuantParams:
return QuantParams(
per_channel=False,
q_input=tensor,
scale=1.0,
zp=0,
axis=0,
dtype=torch.uint8,
qmin=0,
qmax=255,
is_output=self.is_graph_output(tensor),
is_input=self.is_graph_input(tensor),
)

def define_tensor( # noqa: C901
self,
tensor: torch.fx.Node,
@@ -385,6 +400,15 @@ def define_tensor( # noqa: C901
ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
)

# Support U8 tensors by treating them as asymmetric quantized tensors with
# scale=1 and zero_point=0.
if (
"val" in tensor.meta
and tensor.meta["val"].dtype == torch.uint8
and quant_params is None
):
quant_params = self._create_qparams_for_u8(tensor)

# Get new xnn id for tensor value
ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
dims = get_shape(tensor)
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -47,6 +47,7 @@
SoftmaxConfig,
SquareRootConfig,
SubConfig,
ToDimOrderCopyConfig,
UpsampleBilinear2dConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -101,6 +102,7 @@
SoftmaxConfig,
SquareRootConfig,
SubConfig,
ToDimOrderCopyConfig,
UpsampleBilinear2dConfig,
# Quant/Dequant Op Configs
QuantizedPerTensorConfig,
37 changes: 34 additions & 3 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -10,6 +10,9 @@
from typing import cast, List, Optional

import torch
from executorch.backends.xnnpack._passes.replace_u8_convert_with_dq_pass import (
ReplaceU8ConvertWithDqPass,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
XNNPartitionerConfig,
@@ -191,7 +194,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.U8,
ConfigPrecisionType.STATIC_QUANT,
]


class CeilConfig(GenericNodePartitionerConfig):
@@ -330,7 +337,7 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]
return [ConfigPrecisionType.FP32, ConfigPrecisionType.U8]

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.upsample_bilinear2d.vec
@@ -472,7 +479,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.U8,
ConfigPrecisionType.STATIC_QUANT,
]


class SquareRootConfig(GenericNodePartitionerConfig):
@@ -543,3 +554,23 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
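    """
    Partitions u8 -> f32 _to_dim_order_copy nodes so the conversion can be
    handled inside the delegate as a dequantize (see ReplaceU8ConvertWithDqPass).
    """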
target_name = "_to_dim_order_copy.default"

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
if not self.check_common_constraints(node, ep):
return False

if not ReplaceU8ConvertWithDqPass.can_replace_to_copy_node(node):
why(node, reason="Only u8 to f32 conversion is supported")
return False

return True

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten._to_copy.default

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.U8]
11 changes: 11 additions & 0 deletions backends/xnnpack/partition/config/xnnpack_config.py
@@ -25,6 +25,7 @@ class ConfigPrecisionType(Enum):
FP32 = 1
STATIC_QUANT = 2
DYNAMIC_QUANT = 3
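    # Plain (unquantized) uint8 tensors, handled by treating them as asymmetric
    # quantized tensors with scale=1.0 and zero_point=0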
U8 = 4


class XNNPartitionerConfig(PartitionerConfig):
@@ -170,6 +171,9 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
return False

if arg_val.dtype not in valid_dtypes:
                logger.warning(
                    f"Input {node} has invalid dtype {arg_val.dtype}; expected one of {valid_dtypes}"
                )
return False

return True
@@ -188,6 +192,9 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
return False

if val.dtype not in valid_dtypes:
                logger.warning(
                    f"Output {node} has invalid dtype {val.dtype}; expected one of {valid_dtypes}"
                )
return False

return True
@@ -199,6 +206,10 @@ def _check_node_has_valid_dtype(self, node):
torch.int8,
torch.qint8,
}

if ConfigPrecisionType.U8 in self.enabled_precision_types:
valid_dtypes.add(torch.uint8)

if (
node.op != "placeholder"
and node.op != "call_function"
63 changes: 63 additions & 0 deletions backends/xnnpack/test/models/mobilenet_v3.py
@@ -6,12 +6,45 @@

import unittest

from typing import Sequence

import torch
import torchvision
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from torchvision import models


class ResizeAndCropWrapper(torch.nn.Module):
def __init__(
self,
model: torch.nn.Module,
resize_shape: Sequence[int],
crop_shape: Sequence[int],
):
super().__init__()
self.resize_shape = resize_shape
self.crop = torchvision.transforms.CenterCrop(crop_shape)
# Simplified ImageNet normalization expected by pre-trained weights
self.normalize_mean = 0.456
self.normalize_std = 0.225
self.model = model

def forward(self, image):
resized = torch.nn.functional.interpolate(
image,
size=self.resize_shape,
mode="bilinear",
align_corners=False,
antialias=False,
)
cropped = self.crop(resized)
image_f32 = cropped.to(torch.float) / 255.0
normalized = (image_f32 - self.normalize_mean) / self.normalize_std

        return self.model(normalized)


class TestMobileNetV3(unittest.TestCase):
mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True)
mv3 = mv3.eval()
@@ -34,6 +67,8 @@ class TestMobileNetV3(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
"executorch_exir_dialects_edge__ops_aten_div_Tensor",
"executorch_exir_dialects_edge__ops_aten_mean_dim",
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor",
"executorch_exir_dialects_edge__ops_aten_upsample_bilinear2d_vec",
}

def test_fp32_mv3(self):
@@ -48,6 +83,34 @@ def test_fp32_mv3(self):
.run_method_and_compare_outputs(num_runs=5)
)

def test_fp32_mv3_with_u8_resize(self):
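        """
        Run MobileNetV3 on a uint8 image input with dynamic spatial dims; the
        u8 resize/crop and the u8 -> f32 conversion are expected to be lowered
        to the XNNPACK delegate.
        """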
dynamic_shapes = (
{
2: torch.export.Dim("height", min=260, max=1024),
3: torch.export.Dim("width", min=260, max=1024),
},
)
wrapped_model = ResizeAndCropWrapper(
self.mv3,
(260, 260),
(224, 224),
)
u8_inputs = (torch.randint(0, 255, (1, 3, 512, 512)).to(torch.uint8),)
(
Tester(wrapped_model, u8_inputs, dynamic_shapes=dynamic_shapes)
.export()
.dump_artifact()
.to_edge_transform_and_lower()
.dump_artifact()
.check(["torch.ops.higher_order.executorch_call_delegate"])
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
# XNN u8 reshape can differ by 1 from eager mode, leading to a very
# small increase in tolerance.
.run_method_and_compare_outputs(num_runs=5, atol=0.002)
)

@unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
def _test_qs8_mv3(self):
ops_after_lowering = self.all_operators