-
Notifications
You must be signed in to change notification settings - Fork 259
[Torch FX] Compress PT2E Support #3663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 32 commits
190f9d5
c52fcca
4e56cb5
9651ceb
14daeb5
3746815
0815dc5
1b8d940
7d35374
427ebc2
24dbfb6
4bb8c1a
8902842
3842538
2e70c2e
b1c9aad
33fe01c
d8e1006
88a8472
7a8e51a
fed5052
2866473
7171d56
3e3b067
5b7b210
71a479f
b24a59c
d12225a
9870ee2
8015629
0804218
1f1fda3
623ce46
d14a6eb
e91b455
448bf84
8e23572
36ddf53
07b730b
d5dd422
076a76b
2ce9eec
1bebf3e
ea81cfd
e82920f
82cc10b
beae508
8bd95df
aac9d3f
4278cfd
6fd5216
118b611
e9f3cd4
a969e58
71d0597
bf671ff
5f1c2de
6f81879
eb0ff16
8afeb9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) 2025 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (c) 2025 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
|
||
import nncf | ||
from nncf import SensitivityMetric | ||
from nncf.common.graph.graph import NNCFGraph | ||
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer | ||
from nncf.common.utils.backend import BackendType | ||
from nncf.quantization.algorithms.algorithm import Algorithm | ||
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression | ||
|
||
|
||
class WeightsCompressionPT2E(Algorithm):
    """
    Post-training weight compression for torch.fx.GraphModule models captured via PT2E.

    The supplied quantizer provides the weight-compression configuration and the set of
    nodes to compress; the actual compression is delegated to :class:`WeightCompression`.
    """

    def __init__(
        self,
        quantizer,
        subset_size: int = 128,
        awq: bool = False,
        scale_estimation: bool = False,
        gptq: bool = False,
        lora_correction: bool = False,
        sensitivity_metric: nncf.SensitivityMetric = None,
        compression_format: nncf.CompressionFormat = nncf.CompressionFormat.DQ,
        advanced_parameters: nncf.AdvancedCompressionParameters = None,
    ) -> None:
        """
        :param quantizer: Quantizer adapter exposing `get_weight_compression_config()`,
            `get_weight_compression_setup()` and `get_nodes_to_compress()`.
        :param subset_size: Number of data samples used to collect statistics.
        :param awq: Whether to apply the modified AWQ algorithm.
        :param scale_estimation: Whether to apply scale estimation for 4-bit layers.
        :param gptq: Whether to apply the GPTQ algorithm.
        :param lora_correction: Whether to apply the LoRA Correction algorithm.
        :param sensitivity_metric: Sensitivity metric for mixed-precision assignment.
            When None, WEIGHT_QUANTIZATION_ERROR is used and the full compression
            setup is taken from the quantizer (see :meth:`apply`).
        :param compression_format: Format the compressed weights are represented in.
        :param advanced_parameters: Advanced parameters for the compression pipeline.
        """
        self._quantizer = quantizer

        # The quantizer owns the basic compression configuration (mode, ratio, ...).
        wc_config = self._quantizer.get_weight_compression_config()

        mode = wc_config.get("mode", None)
        ratio = wc_config.get("ratio", 1)
        group_size = wc_config.get("group_size", 128)
        all_layers = wc_config.get("all_layers", False)
        backup_mode = wc_config.get("backup_mode", nncf.BackupMode.INT8_ASYM)
        self._sensitivity_metric = sensitivity_metric
        self._algo = WeightCompression(
            mode=mode,
            ratio=ratio,
            group_size=group_size,
            ignored_scope=nncf.IgnoredScope(),  # only compress "nodes_to_compress"
            all_layers=all_layers,
            sensitivity_metric=self._sensitivity_metric or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,
            awq=awq,
            subset_size=subset_size,
            scale_estimation=scale_estimation,
            gptq=gptq,
            lora_correction=lora_correction,
            backup_mode=backup_mode,
            compression_format=compression_format,
            advanced_parameters=advanced_parameters,
        )

    def available_backends(self) -> list[BackendType]:
        """Return the backends supported by the underlying compression algorithm."""
        return self._algo.available_backends()

    def apply(
        self,
        model: torch.fx.GraphModule,
        graph: NNCFGraph,
        statistic_points=None,
        dataset=None,
    ):
        """
        Apply weight compression to the given model.

        :param model: The torch.fx.GraphModule to compress.
        :param graph: NNCFGraph built for the model.
        :param statistic_points: Optional pre-collected statistic points.
        :param dataset: Optional calibration dataset for data-aware algorithms.
        :return: The transformed (weight-compressed) model.
        """
        self._algo.set_backend_entity(model)  # Set algo backend

        if self._sensitivity_metric is None:
            # Default case: the metric is not defined by the user in the API, so the
            # annotation (quantization parameters for all layers) from the quantizer is used.
            all_weight_params = self._quantizer.get_weight_compression_setup(
                model, graph
            )  # Get weight compression params FROM QUANTIZER
            statistics, statistic_points = self._algo.collect_weight_compression_statistics(
                model, graph, dataset, all_weight_params, statistic_points
            )
        else:
            # Data-aware mixed precision: only nodes_to_compress is obtained from the
            # quantizer; the algorithm itself decides per-node parameters.
            nodes_to_compress = self._quantizer.get_nodes_to_compress(
                model, graph
            )  # Get nodes to compress FROM QUANTIZER
            all_weight_params, statistics = self._algo.get_weight_compression_parameters(
                model, graph, nodes_to_compress, statistic_points, dataset
            )

        transformed_model = self._algo.apply_wc_algos(
            model, graph, all_weight_params, statistics, dataset
        )  # Apply the wc algos FROM ALGO
        return transformed_model

    def get_statistic_points(self, model, graph: NNCFGraph) -> StatisticPointsContainer:
        """Return the statistic points required by the underlying algorithm."""
        return self._algo.get_statistic_points(model, graph)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
from nncf.common.logging import nncf_logger | ||
from nncf.common.utils.api_marker import api | ||
from nncf.experimental.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization | ||
from nncf.experimental.quantization.algorithms.weight_compression.algorithm import WeightsCompressionPT2E | ||
from nncf.experimental.torch.fx.constant_folding import constant_fold | ||
from nncf.experimental.torch.fx.quantization.quantizer.openvino_adapter import OpenVINOQuantizerAdapter | ||
from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer | ||
|
@@ -157,3 +158,63 @@ def _quant_node_constraint(n: torch.fx.Node) -> bool: | |
related to quantization | ||
""" | ||
return n.op == "call_function" and n.target in QUANTIZE_NODE_TARGETS | ||
|
||
|
||
@api(canonical_alias="nncf.experimental.torch.fx.compress_pt2e")
def compress_pt2e(
    model: torch.fx.GraphModule,
    quantizer: Quantizer,
    dataset: Optional[nncf.Dataset] = None,
    awq: bool = False,
    scale_estimation: bool = False,
    gptq: bool = False,
    lora_correction: bool = False,
    subset_size: int = 128,
    sensitivity_metric: Optional[nncf.SensitivityMetric] = None,
    advanced_parameters: Optional[nncf.AdvancedCompressionParameters] = None,
) -> torch.fx.GraphModule:
    """
    Applies Weight Compression to the torch.fx.GraphModule provided model
    using provided torch.ao quantizer.

    :param model: A torch.fx.GraphModule instance to be quantized.
    :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
        to convey the desired way of quantization.
    :param dataset: A representative dataset for the calibration process.
    :param awq: Determines whether to use or not the modified AWQ algorithm.
    :param scale_estimation: Determines whether to use or not scale estimation for 4-bit layers.
    :param gptq: Determines whether to use or not GPTQ algorithm.
    :param lora_correction: Determines whether to use or not LoRA Correction algorithm.
    :param subset_size: Number of data samples to calculate activation statistics used for assigning different
        quantization precision.
    :param sensitivity_metric: The sensitivity metric for assigning quantization precision to layers. In order to
        preserve the accuracy of the model, the more sensitive layers receive a higher precision.
    :param advanced_parameters: Advanced parameters for algorithms in the compression pipeline.
    :return: The weight-compressed torch.fx.GraphModule.
    """
    # Duck-typed check so that quantizers wrapping OpenVINOQuantizer are accepted too.
    if isinstance(quantizer, OpenVINOQuantizer) or hasattr(quantizer, "get_nncf_weight_compression_setup"):
        quantizer = OpenVINOQuantizerAdapter(quantizer)
        compression_format = nncf.CompressionFormat.DQ
    else:
        # TODO: support third-party quantizers here.
        msg = "Only OpenVINO Quantizer is supported currently."
        raise nncf.InternalError(msg)

    compression_algorithm = WeightsCompressionPT2E(
        quantizer=quantizer,
        awq=awq,
        subset_size=subset_size,
        scale_estimation=scale_estimation,
        gptq=gptq,
        lora_correction=lora_correction,
        sensitivity_metric=sensitivity_metric,
        compression_format=compression_format,
        advanced_parameters=advanced_parameters,
    )

    # Here the model is annotated by the quantizer prior to compression.
    transformed_model = quantizer.transform_prior_quantization(model)
    nncf_graph = NNCFGraphFactory.create(transformed_model)
    compressed_model = compression_algorithm.apply(transformed_model, nncf_graph, dataset=dataset)
    # Re-wrap so the returned GraphModule owns a recompiled copy of the final graph.
    compressed_model = torch.fx.GraphModule(compressed_model, graph=compressed_model.graph)
    return compressed_model
Uh oh!
There was an error while loading. Please reload this page.