-
Notifications
You must be signed in to change notification settings - Fork 259
[Torch FX] Compress PT2E Support #3663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 1 commit
190f9d5
c52fcca
4e56cb5
9651ceb
14daeb5
3746815
0815dc5
1b8d940
7d35374
427ebc2
24dbfb6
4bb8c1a
8902842
3842538
2e70c2e
b1c9aad
33fe01c
d8e1006
88a8472
7a8e51a
fed5052
2866473
7171d56
3e3b067
5b7b210
71a479f
b24a59c
d12225a
9870ee2
8015629
0804218
1f1fda3
623ce46
d14a6eb
e91b455
448bf84
8e23572
36ddf53
07b730b
d5dd422
076a76b
2ce9eec
1bebf3e
ea81cfd
e82920f
82cc10b
beae508
8bd95df
aac9d3f
4278cfd
6fd5216
118b611
e9f3cd4
a969e58
71d0597
bf671ff
5f1c2de
6f81879
eb0ff16
8afeb9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -799,7 +799,7 @@ def _collect_statistics_and_statistic_points( | |
|
||
:param model: Backend-specific model instance. | ||
:param graph: Corresponding NNCFGraph of the model. | ||
:param statistic_points: Container with pre-collected statistics, if available. | ||
:param statistic_points: Statistic points. | ||
|
||
:param dataset: Dataset used for collecting statistics when not provided. | ||
:param ratio_defining_params: List of parameters defining compression ratios. | ||
:param all_weight_params: List of all weight compression parameters. | ||
|
@@ -832,18 +832,20 @@ def get_weight_compression_parameters( | |
list[WeightCompressionParameters], | ||
]: | ||
""" | ||
Generates a list of weight compression parameters based on the Weight Compression algorithm | ||
configuration. Determines the appropriate quantization parameters for each node eligible for | ||
weight compression. Also, returns a list of ratio defining parameters which are a subset of | ||
all_weight_parameters. This is based on parameters like all_layers. Lastly, it gives a list | ||
of skipped layers based on parameters like ignored scope or depending on the group size value | ||
adjustment. | ||
This Function: | ||
1. Generates a list of weight compression parameters based on the algorithm configuration. | ||
2. Determines the appropriate quantization parameters for each node eligible for weight compression. | ||
3. Generates a subset of parameters that can be compressed in both primary and backup precisions, | ||
called ratio-defining parameters. All ratio-defining parameters are set to the primary precision. | ||
4. Generates a subset of parameters that will not be compressed, based on the ignored scope or | ||
compression configuration restrictions. | ||
|
||
:param model: Backend-specific input model. | ||
:param graph: NNCFGraph instance. | ||
:return: A tuple consisting of a list of all weight compression parameters, based on the Weight | ||
Compression algorithm configuration, list of ratio defining parameters(weights that are used | ||
for ratio calculation between primary and backup precisions), and list of weight parameters to skip. | ||
:return: A tuple consisting a list of weight compression parameters that can be compressed, | ||
a list of ratio-defining parameters, which is a subset of compressible weight parameters | ||
that are allowed to be set to mixed precisions, and a list of weight compression parameters | ||
that can not be compressed. | ||
|
||
""" | ||
nodes_to_compress = self.get_nodes_to_compress(graph) | ||
|
||
|
@@ -916,7 +918,7 @@ def get_weight_compression_parameters( | |
else: | ||
group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params} | ||
|
||
# Set these layers to primary config. Later we will set layers to backup precision according to Mixed precision | ||
# Set each ratio defining parameter to primary config | ||
for weight_param in ratio_defining_params: | ||
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name]) | ||
|
||
|
@@ -954,7 +956,23 @@ def apply_with_parameters( | |
all_weight_params: list[WeightCompressionParameters], | ||
ratio_defining_params: list[WeightCompressionParameters], | ||
skipped_weight_params: list[WeightCompressionParameters], | ||
): | ||
) -> TModel: | ||
""" | ||
Applies the Weight Compression algorithm using precomputed parameters and optional | ||
algorithms (AWQ, GPTQ, scale estimation, LoRA correction). The method collects | ||
statistics, configures the weight compression parameters for mixed precision algorithm, | ||
and performs the model transformation with appropriate decompression operations | ||
|
||
:param model: Backend-specific model to be compressed. | ||
:param graph: NNCFGraph instance. | ||
:param dataset: Dataset to collect statistics. | ||
:param statistic_points: Statistics points object. | ||
:param all_weight_params: List of all weight parameters. | ||
:param ratio_defining_params: Subset of all_weight_params that determine mixed-precision ratios. | ||
:param skipped_weight_params: List of parameters corresponding to weights intentionally skipped | ||
from compression (e.g., due to ignored scopes or group size adjustments). | ||
:return: Transformed model with compressed weights and inserted backend-specific decompressor. | ||
""" | ||
# Collect statistics for the weights compression | ||
statistics, statistic_points = self._collect_statistics_and_statistic_points( | ||
model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params | ||
|
@@ -1179,7 +1197,7 @@ def _get_statistics_for_weights_compression( | |
|
||
:param matmul_input_to_output_nodes_map: A mapping from activation node and a port id to corresponding matmul | ||
nodes which accept this activation as an input. | ||
:param statistic_points: Statistic points object. | ||
:param statistic_points: Statistic points. | ||
:return: Collected statistics. | ||
""" | ||
# For each node we store statistics in a WCTensorStatistics data-class. It contains the following fields: | ||
|
Uh oh!
There was an error while loading. Please reload this page.