From 33543c0ab7e0cf42a7f850b9b7bde57c29fbf669 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 2 Oct 2025 16:53:56 +0200
Subject: [PATCH 1/2] Split get_weight_compression_parameters into
 get_params/collect_statistics
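
Previously, get_weight_compression_parameters() both selected the compression
parameters and collected activation statistics. After this change it is
data-free and returns (all_weight_params, ratio_defining_params,
group_size_values), while statistics collection moves into the new
collect_statistics_and_statistic_points() method. A rough sketch of the
resulting flow inside apply() (the `wc` receiver name is illustrative only):

    # Parameter selection no longer needs a dataset.
    all_weight_params, ratio_defining_params, group_size_values = (
        wc.get_weight_compression_parameters(model, graph)
    )
    # Statistics collection is now a separate step; it returns
    # (None, statistic_points) when no dataset is given or the
    # algorithm is not data-aware.
    statistics, statistic_points = wc.collect_statistics_and_statistic_points(
        model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
    )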
---
 .../weight_compression/algorithm.py | 84 ++++++++-----------
 1 file changed, 35 insertions(+), 49 deletions(-)

diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
index d0407a1eff4..202a8d57097 100644
--- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -773,8 +773,6 @@ def get_weight_compression_parameters(
         self,
         model: TModel,
         graph: NNCFGraph,
-        statistic_points: Optional[StatisticPointsContainer] = None,
-        dataset: Optional[Dataset] = None,
-    ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
+    ) -> tuple[list[WeightCompressionParameters], list[WeightCompressionParameters], dict[str, int]]:
         """
         Generates a list of weight compression parameters based on the Weight Compression algorithm
@@ -869,37 +867,34 @@ def get_weight_compression_parameters(
         else:
             group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}
 
-        # Collect statistics for the weights compression
-        statistics = None
-        if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
-            weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
-            matmul_nodes_to_compress = [
-                wp.node_with_weight
-                for wp in weight_params
-                if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
-            ]
-            matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(
-                matmul_nodes_to_compress, graph
-            )
-            if statistic_points is None:
-                statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
-            statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)
-            statistics = self._get_statistics_for_weights_compression(
-                matmul_input_to_output_nodes_map, statistic_points
-            )
-
-        # Set weight compression configuration
-        self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)
-
         # Print statistics
         nncf_logger.info(
             self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
         )
 
-        # Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
-        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
+        return all_weight_params, ratio_defining_params, group_size_values
 
-        return all_weight_params, statistics
+    def collect_statistics_and_statistic_points(
+        self, model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
+    ):
+        if not dataset or not (self._data_aware_mixed_precision or self._data_aware_compression):
+            return None, statistic_points
+        weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
+        matmul_nodes_to_compress = [
+            wp.node_with_weight
+            for wp in weight_params
+            if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
+        ]
+        matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
+        if statistic_points is None:
+            statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
+        statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
+        statistics_aggregator.register_statistic_points(statistic_points)
+        statistics_aggregator.collect_statistics(model, graph)
+        statistic_points = statistics_aggregator.statistic_points
+        return self._get_statistics_for_weights_compression(
+            matmul_input_to_output_nodes_map, statistic_points
+        ), statistic_points
 
     def apply(
         self,
@@ -911,7 +906,18 @@ def apply(
         self.set_backend_entity(model)
 
         # Get processed weight compression parameters ready for compression
-        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
+        all_weight_params, ratio_defining_params, group_size_values = self.get_weight_compression_parameters(
+            model, graph
+        )
+        # Collect statistics for the weights compression
+        statistics, statistic_points = self.collect_statistics_and_statistic_points(
+            model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
+        )
+        # Set weight compression configuration
+        self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)
+
+        # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
+        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
 
         if self._awq:
             model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
@@ -1048,26 +1054,6 @@ def get_compression_nodes_info(
         matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
         return nodes_to_compress, matmul_input_to_output_nodes_map
 
-    def _collect_statistics(
-        self,
-        dataset: Dataset,
-        graph: NNCFGraph,
-        model: TModel,
-        statistic_points: StatisticPointsContainer,
-    ):
-        """
-        Creates statistics aggregator, registers all statistics specified for algorithm, and then collect them.
-
-        :param dataset: Dataset to collect values.
-        :param graph: Model graph.
-        :param model: Model for statistics collection.
-        :param statistic_points: Statistics points.
-        """
-        statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
-        statistics_aggregator.register_statistic_points(statistic_points)
-        statistics_aggregator.collect_statistics(model, graph)
-        return statistics_aggregator.statistic_points
-
     def get_statistic_points(
         self,
         model: TModel,

From d593d7ab75b4fc799a6eca9c73d38de786be4be8 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 2 Oct 2025 17:44:15 +0200
Subject: [PATCH 2/2] Introduce apply_with_parameters
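
apply() now only computes the compression parameters and delegates the rest
of the pipeline (statistics collection, config assignment, filtering, and the
compression itself) to the new apply_with_parameters() method;
collect_statistics_and_statistic_points() also becomes private. A minimal
usage sketch of the split entry points, assuming a caller that wants to
inspect or adjust the parameters before compression (variable names are
illustrative only):

    all_weight_params, ratio_defining_params, group_size_values = (
        wc.get_weight_compression_parameters(model, graph)
    )
    # ...the caller may inspect or tweak all_weight_params here...
    compressed_model = wc.apply_with_parameters(
        model,
        graph,
        dataset,
        statistic_points,
        all_weight_params,
        ratio_defining_params,
        group_size_values,
    )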
- """ - statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) - statistics_aggregator.register_statistic_points(statistic_points) - statistics_aggregator.collect_statistics(model, graph) - return statistics_aggregator.statistic_points - def get_statistic_points( self, model: TModel, From d593d7ab75b4fc799a6eca9c73d38de786be4be8 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 2 Oct 2025 17:44:15 +0200 Subject: [PATCH 2/2] apply_with_parameters --- .../weight_compression/algorithm.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py index 202a8d57097..70a3f1d99f4 100644 --- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -874,7 +874,7 @@ def get_weight_compression_parameters( return all_weight_params, ratio_defining_params, group_size_values - def collect_statistics_and_statistic_points( + def _collect_statistics_and_statistic_points( self, model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params ): if not dataset or not (self._data_aware_mixed_precision or self._data_aware_compression): @@ -909,8 +909,28 @@ def apply( all_weight_params, ratio_defining_params, group_size_values = self.get_weight_compression_parameters( model, graph ) + return self.apply_with_parameters( + model, + graph, + dataset, + statistic_points, + all_weight_params, + ratio_defining_params, + group_size_values, + ) + + def apply_with_parameters( + self, + model, + graph, + dataset, + statistic_points, + all_weight_params, + ratio_defining_params, + group_size_values, + ): # Collect statistics for the weights compression - statistics, statistic_points = self.collect_statistics_and_statistic_points( + statistics, statistic_points = self._collect_statistics_and_statistic_points( model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params ) # Set weight compression configuration