[Torch FX] Compress PT2E Support #3663
base: develop
Changes from 36 commits
@@ -0,0 +1,10 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,98 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import nncf
from nncf import SensitivityMetric
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression


class WeightsCompressionPT2E(Algorithm):

Suggested change:
-class WeightsCompressionPT2E(Algorithm):
+class WeightCompression(Algorithm):
Should I rename it to ExperimentalWeightCompression instead, since it could be confused with the original?
It is inside the experimental directory; that should be descriptive enough. I suggest the WeightCompression name.
Done.
Type hints and a docstring are missing.
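For reference, a minimal sketch of a type-hinted, documented class header; the docstring wording, backend value, parameter set, and method body below are illustrative assumptions, not code from this PR:

from typing import Optional

import torch

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.quantization.algorithms.algorithm import Algorithm


class WeightsCompressionPT2E(Algorithm):
    """
    Post-training weight compression for torch.fx.GraphModule models captured via the
    PT2E (torch.export) workflow. Delegates the compression logic to the shared
    WeightCompression implementation.
    """

    @property
    def available_backends(self) -> list[BackendType]:
        # Assumed backend value; the real class targets the Torch FX backend.
        return [BackendType.TORCH_FX]

    def apply(
        self,
        model: torch.fx.GraphModule,
        graph: NNCFGraph,
        statistic_points: Optional[StatisticPointsContainer] = None,
        dataset: Optional[Dataset] = None,
    ) -> torch.fx.GraphModule:
        """Compresses the weights of the given model; see the PR for the real body."""
        raise NotImplementedError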
@@ -102,7 +102,7 @@ def get_weight_compression_configuration(
)

return {
"mode": mode,
"mode": mode if isinstance(mode, nncf.CompressWeightsMode) else nncf.CompressWeightsMode(mode),
"ratio": ratio or 1,
"group_size": group_size,
"all_layers": all_layers or False,
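The guard added here normalizes a plain string into the enum. A small illustration of the accepted inputs, assuming the lowercase string values used by nncf.CompressWeightsMode (e.g. "int4_sym"):

import nncf

def normalize_mode(mode) -> nncf.CompressWeightsMode:
    # Mirrors the expression in the diff: enum members pass through, strings are converted.
    return mode if isinstance(mode, nncf.CompressWeightsMode) else nncf.CompressWeightsMode(mode)

assert normalize_mode("int4_sym") is nncf.CompressWeightsMode.INT4_SYM
assert normalize_mode(nncf.CompressWeightsMode.INT8_ASYM) is nncf.CompressWeightsMode.INT8_ASYM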
@@ -527,11 +527,8 @@ def _set_weight_compression_config(
primary_precision_weight_params = self._mixed_precision_algo.apply(
model, graph, statistics_points, weight_params=ratio_defining_params
)
else:
primary_precision_weight_params = ratio_defining_params

for weight_param in primary_precision_weight_params:
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])
for weight_param in primary_precision_weight_params:
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])

# Check if group size is valid for each weight in ratio_defining_params
failed_nodes = []
@@ -769,12 +766,32 @@ def is_weight_compression_supported(

return is_supported_dtype and not no_bit_reduction

def _collect_statistics_and_statistic_points(
self, model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
):
if not dataset or not (self._data_aware_mixed_precision or self._data_aware_compression):
return None, statistic_points
weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
matmul_nodes_to_compress = [
wp.node_with_weight
for wp in weight_params
if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
]
matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
if statistic_points is None:
statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
statistics_aggregator.register_statistic_points(statistic_points)
statistics_aggregator.collect_statistics(model, graph)
statistic_points = statistics_aggregator.statistic_points
return self._get_statistics_for_weights_compression(
matmul_input_to_output_nodes_map, statistic_points
), statistic_points

def get_weight_compression_parameters(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
"""
Generates a list of weight compression parameters based on the Weight Compression algorithm
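For context, the early return above corresponds to the data-free path of nncf.compress_weights, while passing a dataset together with a data-aware option takes the statistics path. A rough user-level sketch for a PT2E-captured model follows; whether each option is available for the Torch FX backend depends on the NNCF version, so treat this as an illustration of when the branch is taken rather than a supported recipe:

import torch
import nncf

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 64)

    def forward(self, x):
        return self.linear(x)

example_input = torch.randn(1, 128)

# Data-free: no dataset, so _collect_statistics_and_statistic_points returns (None, statistic_points).
exported = torch.export.export(TinyModel(), (example_input,)).module()
data_free = nncf.compress_weights(exported, mode=nncf.CompressWeightsMode.INT8_ASYM)

# Data-aware: a dataset plus ratio < 1 (mixed precision) makes the algorithm aggregate
# activation statistics before assigning per-weight configurations.
exported = torch.export.export(TinyModel(), (example_input,)).module()
calibration = nncf.Dataset([torch.randn(1, 128) for _ in range(8)])
data_aware = nncf.compress_weights(
    exported,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    dataset=calibration,
)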
@@ -869,37 +886,18 @@ def get_weight_compression_parameters(
else:
group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}

# Collect statistics for the weights compression
statistics = None
if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
matmul_nodes_to_compress = [
wp.node_with_weight
for wp in weight_params
if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
]
matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(
matmul_nodes_to_compress, graph
)
if statistic_points is None:
statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)
statistics = self._get_statistics_for_weights_compression(
matmul_input_to_output_nodes_map, statistic_points
)

# Set weight compression configuration
self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)

# If no mixed precision has to be applied, then set the primary config for all ratio defining params.
if self._ratio == 1 or len(ratio_defining_params) == 0:

for weight_param in ratio_defining_params:
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])

# Print statistics
nncf_logger.info(
self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
)

# Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
return all_weight_params, ratio_defining_params, group_size_values, skipped_weight_params

return all_weight_params, statistics

def apply(
self,
@@ -911,7 +909,45 @@ def apply(
self.set_backend_entity(model)

# Get processed weight compression parameters ready for compression
all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
all_weight_params, ratio_defining_params, group_size_values, skipped_weight_params = self.get_weight_compression_parameters(
model, graph
)
return self.apply_with_parameters(
model,
graph,
dataset,
statistic_points,
all_weight_params,
ratio_defining_params,
group_size_values,
skipped_weight_params,
)

def apply_with_parameters(
self,
model,
graph,
dataset,
statistic_points,
all_weight_params,
ratio_defining_params,
group_size_values,
skipped_weight_params,
):
# Collect statistics for the weights compression
statistics, statistic_points = self._collect_statistics_and_statistic_points(
model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
)
# Set weight compression configuration
self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)

# Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))

# Print statistics
nncf_logger.info(
self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
)

if self._awq:
model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
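The apply/apply_with_parameters split above is what lets a frontend collect, and optionally adjust, the per-weight parameters before anything is written to the model. A hypothetical sketch of how a caller (for example the new PT2E class) could drive the two steps, reusing the names from this diff:

def compress_in_two_steps(algo, model, graph, dataset=None, statistic_points=None):
    # Step 1: analysis only, no model mutation yet.
    (
        all_weight_params,
        ratio_defining_params,
        group_size_values,
        skipped_weight_params,
    ) = algo.get_weight_compression_parameters(model, graph)

    # A caller could inspect or tweak the collected parameters here, e.g. override the
    # compression config of selected layers, before the model is actually transformed.

    # Step 2: statistics collection (if data-aware) and model transformation.
    return algo.apply_with_parameters(
        model,
        graph,
        dataset,
        statistic_points,
        all_weight_params,
        ratio_defining_params,
        group_size_values,
        skipped_weight_params,
    )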
@@ -1048,26 +1084,6 @@ def get_compression_nodes_info(
matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
return nodes_to_compress, matmul_input_to_output_nodes_map

def _collect_statistics(
self,
dataset: Dataset,
graph: NNCFGraph,
model: TModel,
statistic_points: StatisticPointsContainer,
):
"""
Creates statistics aggregator, registers all statistics specified for algorithm, and then collect them.

:param dataset: Dataset to collect values.
:param graph: Model graph.
:param model: Model for statistics collection.
:param statistic_points: Statistics points.
"""
statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
statistics_aggregator.register_statistic_points(statistic_points)
statistics_aggregator.collect_statistics(model, graph)
return statistics_aggregator.statistic_points

def get_statistic_points(
self,
model: TModel,
@@ -1147,4 +1163,4 @@ def _get_statistics_for_weights_compression(
# Each activation node may have multiple MatMul nodes which it is an input to
for node in matmul_nodes:
statistics[node.node_name] = copy.deepcopy(stats)
return statistics
return statistics
Ah yes, good catch! I will change it.
Done