[TorchFX] Use torchao for quantize_pt2e API when possible #3588
base: develop
nncf/experimental/torch/fx/quantization/qdq_parameters.py (new file)

@@ -0,0 +1,42 @@
+# Copyright (c) 2025 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class TorchQDQParameters:
+    """
+    Stores the quantization parameters required for
+    creation of a PyTorch quantize-dequantize pair.
+    :param quant_min: Minimum quant value.
+    :type quant_min: int
+    :param quant_max: Maximum quant value.
+    :type quant_max: int
+    :param scale: Scale factor used for quantization.
+    :type scale: torch.Tensor
+    :param zero_point: The quantized value to which floating-point zero maps.
+    :type zero_point: torch.Tensor
+    :param is_per_channel: Whether quantization is applied per channel.
+    :type is_per_channel: bool
+    :param ch_axis: Channel axis used for per-channel quantization.
+    :type ch_axis: int
+    """
+
+    quant_min: int
+    quant_max: int
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+    is_per_channel: bool
+    ch_axis: int
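For context, constructing a per-tensor parameter set could look like the following sketch (all values here are hypothetical):

```python
import torch

from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters

# Hypothetical int8 per-tensor parameters: scale and zero_point are 0-d
# tensors, and ch_axis is ignored when is_per_channel is False.
params = TorchQDQParameters(
    quant_min=-128,
    quant_max=127,
    scale=torch.tensor(0.02),
    zero_point=torch.tensor(0),
    is_per_channel=False,
    ch_axis=0,
)
```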
nncf/experimental/torch/fx/transformations.py
@@ -15,12 +15,9 @@

 import torch
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase
 from torch.fx.passes.infra.pass_base import PassResult
-from torch.quantization.fake_quantize import FakeQuantize

 import nncf
 import nncf.torch

@@ -29,6 +26,7 @@
 from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
+from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
 from nncf.torch.graph.transformations.commands import PTTargetPoint

 TransformationFNType = Callable[[torch.fx.GraphModule], None]
@@ -223,16 +221,16 @@ def constant_update_fn(


 def qdq_insertion_transformation_builder(
-    quantizer: FakeQuantize, target_points: list[PTTargetPoint]
+    parameters: TorchQDQParameters, target_points: list[PTTargetPoint]
 ) -> TransformationFNType:
     """
-    Returns transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    Returns transformation which inserts quantize-dequantize operations with
+    the given parameters to each given target point.

-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     :param target_points: List of target points used to insert quantize-dequantize pairs.
-    :return: Transformation which inserts quantize-dequantize operations with parameters
-    inherited from the given quantizer to each given target point.
+    :return: Transformation which inserts quantize-dequantize operations with
+        the given parameters to each given target point.
     """

     def qdq_insertion_transformation(model: torch.fx.GraphModule):

@@ -243,7 +241,7 @@ def qdq_insertion_transformation(model: torch.fx.GraphModule):
             )
             raise nncf.InternalError(msg)
         for target_point in target_points:
-            insert_one_qdq(model, target_point, quantizer)
+            insert_one_qdq(model, target_point, parameters)

     return qdq_insertion_transformation
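With this change callers construct the quantization parameters directly instead of passing a FakeQuantize module. A rough usage sketch follows; the node name, target type, and `gm` are illustrative, not taken from this PR:

```python
import torch

from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.quantization.qdq_parameters import TorchQDQParameters
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
from nncf.torch.graph.transformations.commands import PTTargetPoint

params = TorchQDQParameters(
    quant_min=-128,
    quant_max=127,
    scale=torch.tensor(0.02),
    zero_point=torch.tensor(0),
    is_per_channel=False,
    ch_axis=0,
)
# Hypothetical target: quantize the first input of a node named "conv2d".
target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, "conv2d", input_port_id=0)
transformation = qdq_insertion_transformation_builder(params, [target_point])
transformation(gm)  # gm: a torch.fx.GraphModule captured from the model
```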
@@ -311,38 +309,38 @@ def output_insertion_transformation(model: torch.fx.GraphModule):
     return output_insertion_transformation


-def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
+def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, parameters: TorchQDQParameters):
     """
     Inserts quantize-dequantize after the target node to the target model.

     :param model: Target model.
     :param target_point: Target point; the quantize-dequantize pair is inserted just after the
         target node.
-    :param quantizer: Quantizer module to inherit quantization parameters from.
+    :param parameters: Quantization parameters.
     """
-    # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e
+    # Copied from torchao.quantization.quantize_pt2e.convert_pt2e
     # 1. extract information for inserting q/dq node from activation_post_process
     node_type = "call_function"
     quantize_op: Optional[Callable] = None

-    dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8
-    if quantizer.is_per_channel:
+    dtype = torch.int8 if parameters.quant_min < 0 else torch.uint8
+    if parameters.is_per_channel:
         qparams = {
-            "_scale_": quantizer.scale,
-            "_zero_point_": quantizer.zero_point,
-            "_axis_": quantizer.ch_axis,
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": parameters.scale,
+            "_zero_point_": parameters.zero_point,
+            "_axis_": parameters.ch_axis,
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
         dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
     else:
         qparams = {
-            "_scale_": float(quantizer.scale),
-            "_zero_point_": int(quantizer.zero_point),
-            "_quant_min_": quantizer.quant_min,
-            "_quant_max_": quantizer.quant_max,
+            "_scale_": float(parameters.scale),
+            "_zero_point_": int(parameters.zero_point),
+            "_quant_min_": parameters.quant_min,
+            "_quant_max_": parameters.quant_max,
             "_dtype_": dtype,
         }
         quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
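The decomposed ops used above are ordinary `torch.ops` callables, so the numeric effect of an inserted pair can be checked directly. A minimal sketch, assuming the `quantized_decomposed` op library is registered (importing `torch.ao.quantization.fx._decomposed` does this in current PyTorch releases):

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (registers quantized_decomposed ops)

x = torch.randn(4)
scale, zero_point, qmin, qmax = 0.02, 0, -128, 127
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, qmin, qmax, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, qmin, qmax, torch.int8
)
# For values inside the representable range, dq approximates x with an
# element-wise error of at most scale / 2.
```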
@@ -721,19 +719,6 @@ def match_filters(match, original_graph, graph):
     _set_meta_for_matches(model, matches)


-def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
-    """
-    Applies quantization transformations to the model.
-
-    :param model: Model to apply transformations to.
-    """
-    # BatchNorm operations have 3 output ports,
-    # to make it easier for algorithms to work
-    # with the target graph BatchNorm operations
-    # are being fused
-    _fuse_conv_bn_(model)
-
-
 def fold_constant_except_qdq(model: torch.fx.GraphModule):
     """
     Performs constant folding avoiding quantize-dequantize pattern.
@@ -826,3 +811,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+def get_device(module: torch.nn.Module) -> torch.device:
[Review comment] Please reuse Line 416 in cc935e4.
""" | ||||
Retrieves device of the first parameter of the given module. | ||||
If there are no parameters - returns CPU device. | ||||
|
||||
:param module: A torch.nn.Module instance. | ||||
:return: A device of the first parameter of the given module. | ||||
If there are no parameters - returns CPU device. | ||||
""" | ||||
try: | ||||
named_param = next(module.parameters()) | ||||
except StopIteration: | ||||
named_param = None | ||||
if named_param is None: | ||||
return torch.device("cpu") | ||||
return named_param.device | ||||
+
+
+def create_getattr_from_value(module: torch.nn.Module, graph: torch.fx.Graph, prefix: str, value: Any) -> torch.fx.Node:
[Review comment] I did not find a case where value is not a torch.Tensor; is Any really needed here?
""" | ||||
Given a value of any type, creates a getattr node corresponding to the value and | ||||
registers the value as a buffer to the module. | ||||
|
||||
:param module: A torch.nn.Module instance. | ||||
:param graph: A torch.fx.Graph instance. | ||||
:param prefix: A string to use as a name prefix for the new getattr node. | ||||
:param value: A value | ||||
:return: A getattr node corresponding to the given value. | ||||
""" | ||||
|
||||
def get_new_attr_name(module: torch.nn.Module, prefix: str): | ||||
def get_attr_name(i: int): | ||||
+            return prefix + str(i)
+
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    attr_name = get_new_attr_name(module, prefix.replace(".", "_"))
+    device = get_device(module)
+    new_value = value.detach().clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device)
+    module.register_buffer(attr_name, new_value)
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
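A quick sketch of how this helper behaves (the module and names below are illustrative):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


gm = torch.fx.symbolic_trace(M())
# Registers the tensor as buffer "scale0" (the numeric suffix avoids name
# collisions) and returns a get_attr node pointing at it.
scale_node = create_getattr_from_value(gm, gm.graph, "scale", torch.tensor(0.02))
assert scale_node.op == "get_attr"
assert hasattr(gm, scale_node.target)
```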
[Review comment] `:type` in docstrings is used only for API objects.