diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
index d0407a1eff4..fa25dadfbd4 100644
--- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -102,7 +102,7 @@ def get_weight_compression_configuration(
     )
 
     return {
-        "mode": mode,
+        "mode": mode if isinstance(mode, nncf.CompressWeightsMode) else nncf.CompressWeightsMode(mode),
         "ratio": ratio or 1,
         "group_size": group_size,
         "all_layers": all_layers or False,
@@ -508,9 +508,9 @@ def _get_primary_config(self, group_size: int) -> WeightCompressionConfig:
     def _set_weight_compression_config(
         self,
         ratio_defining_params: list[WeightCompressionParameters],
+        primary_precision_weight_params: list[WeightCompressionParameters],
         model: TModel,
         graph: NNCFGraph,
-        statistics_points: StatisticPointsContainer,
         group_size_values: dict[str, int],
     ) -> None:
         """
@@ -520,16 +520,8 @@ def _set_weight_compression_config(
             backup precisions.
+        :param primary_precision_weight_params: Subset of weight parameters that receive
+            the primary precision.
         :param model: The model.
         :param graph: The model graph associated with the model.
-        :param statistics_points: Statistics points.
         :param group_size_values: A dictionary mapping weight names to their group size values.
         """
-        if self._ratio < 1 and len(ratio_defining_params) > 0:
-            primary_precision_weight_params = self._mixed_precision_algo.apply(
-                model, graph, statistics_points, weight_params=ratio_defining_params
-            )
-        else:
-            primary_precision_weight_params = ratio_defining_params
-
         for weight_param in primary_precision_weight_params:
             weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])
 
@@ -769,10 +761,89 @@ def is_weight_compression_supported(
 
         return is_supported_dtype and not no_bit_reduction
 
+    def collect_weight_compression_statistics(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        dataset: Optional[Dataset],
+        weight_params: list[WeightCompressionParameters],
+        statistic_points: Optional[StatisticPointsContainer] = None,
+    ) -> tuple[Optional[dict[str, Any]], Optional[StatisticPointsContainer]]:
+        """
+        Collects statistics for weight compression if data-aware compression or
+        mixed-precision is enabled.
+
+        :param model: Backend-specific input model.
+        :param graph: NNCFGraph instance.
+        :param dataset: Dataset for statistics collection.
+        :param weight_params: Weight parameters for which to collect statistics.
+        :param statistic_points: Optional pre-collected statistic points.
+        :return: A tuple of the collected statistics (None if statistics collection is not
+            applicable) and the statistic points they were collected from.
+ """ + statistics = None + if not (self._data_aware_mixed_precision or self._data_aware_compression) or not dataset: + return statistics, statistic_points + matmul_nodes_to_compress = [ + wp.node_with_weight + for wp in weight_params + if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes + ] + matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph) + + if statistic_points is None: + statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys()) + statistic_points = self._collect_statistics(dataset, graph, model, statistic_points) + + statistics = self._get_statistics_for_weights_compression(matmul_input_to_output_nodes_map, statistic_points) + return statistics, statistic_points + + def is_last_layer_skipped( + self, + skipped_params: list[WeightCompressionParameters], + nodes_to_compress: list[NNCFNode], + ) -> bool: + """ + Returns True if the final node in nodes_to_compress does not appear in compressed weights. + """ + if not (nodes_to_compress and skipped_params): + return False + last_node = nodes_to_compress[-1] + return any(param.node_with_weight == last_node for param in skipped_params) + + def get_skipped_weight_compression_parameters( + self, + model: TModel, + graph: NNCFGraph, + nodes_to_compress: list[NNCFNode], + ) -> list[WeightCompressionParameters]: + skipped_weight_params: list[WeightCompressionParameters] = [] + weight_names = set() + # ignored_names = self.get_ignored_node_names(graph) + + for param in nodes_to_compress: + node = param.node_with_weight + # is_target_node = should_consider_scope(node.node_name, ignored_names) + for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph): + weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph) + weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph) + reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph) + + wc_config = None + should_skip = not self.is_weight_compression_supported(weight_dtype, self._mode) + if should_skip: + skipped_weight_params.append( + WeightCompressionParameters( + weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config + ) + ) + weight_names.add(weight_name) + return skipped_weight_params + def get_weight_compression_parameters( self, model: TModel, graph: NNCFGraph, + nodes_to_compress: list[NNCFNode], statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]: @@ -791,8 +862,6 @@ def get_weight_compression_parameters( Compression algorithm configuration, and a mapping of target node names to the collected statistics. 
""" - nodes_to_compress = self.get_nodes_to_compress(graph) - all_weight_params: list[WeightCompressionParameters] = [] skipped_weight_params: list[WeightCompressionParameters] = [] @@ -849,16 +918,29 @@ def get_weight_compression_parameters( all_weight_params.append(weight_params) weight_names.add(weight_name) else: - is_last_layer_skipped = is_last_layer - skipped_weight_params.append( - WeightCompressionParameters( - weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config - ) + weight_params = WeightCompressionParameters( + weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, None ) + all_weight_params.append(weight_params) + return all_weight_params + + def apply_wc_algos( + self, + model: TModel, + graph: NNCFGraph, + all_weight_params: list[WeightCompressionParameters], + nodes_to_compress, + statistics: dict[str, Any], + statistic_points, + dataset: Optional[Dataset] = None, + ) -> TModel: + + skipped_weight_params = self.get_skipped_weight_compression_parameters(model, graph, all_weight_params) + is_last_layer_skipped = self.is_last_layer_skipped(skipped_weight_params, nodes_to_compress) # Get subset of nodes to define compression ratio ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped) - + # Handle group size fallback modes if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE: all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback( @@ -868,51 +950,23 @@ def get_weight_compression_parameters( ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params) else: group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params} - - # Collect statistics for the weights compression - statistics = None - if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset: - weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params - matmul_nodes_to_compress = [ - wp.node_with_weight - for wp in weight_params - if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes - ] - matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map( - matmul_nodes_to_compress, graph - ) - if statistic_points is None: - statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys()) - statistic_points = self._collect_statistics(dataset, graph, model, statistic_points) - statistics = self._get_statistics_for_weights_compression( - matmul_input_to_output_nodes_map, statistic_points - ) - + + if self._ratio < 1 and len(ratio_defining_params) > 0: + primary_precision_weight_params = self._mixed_precision_algo.apply( + model, graph, statistic_points, weight_params=ratio_defining_params + ) + else: + primary_precision_weight_params = ratio_defining_params # Set weight compression configuration - self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values) + self._set_weight_compression_config(ratio_defining_params, primary_precision_weight_params, model, graph, group_size_values) # Print statistics nncf_logger.info( self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params) ) - # Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision all_weight_params = list(filter(lambda w_params: 
         all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
 
-        return all_weight_params, statistics
-
-    def apply(
-        self,
-        model: TModel,
-        graph: NNCFGraph,
-        statistic_points: Optional[StatisticPointsContainer] = None,
-        dataset: Optional[Dataset] = None,
-    ) -> TModel:
-        self.set_backend_entity(model)
-
-        # Get processed weight compression parameters ready for compression
-        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
-
         if self._awq:
             model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
             # After applying AWQ we need to update statistics since AWQ alters the activations
@@ -983,6 +1037,29 @@ def apply(
             },
             algo_name="weight_compression",
         )
+
+        return transformed_model
+
+    def apply(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TModel:
+        self.set_backend_entity(model)
+        nodes_to_compress = self.get_nodes_to_compress(graph)
+
+        # Get processed weight compression parameters ready for compression
+        all_weight_params = self.get_weight_compression_parameters(
+            model, graph, nodes_to_compress, statistic_points, dataset
+        )
+        statistics, statistic_points = self.collect_weight_compression_statistics(
+            model, graph, dataset, all_weight_params, statistic_points
+        )
+
+        transformed_model = self.apply_wc_algos(
+            model, graph, all_weight_params, nodes_to_compress, statistics, statistic_points, dataset
+        )
+        return transformed_model
 
     def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]:
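For context, a minimal usage sketch of the public entry point that drives this refactored pipeline. The toy model and parameter values below are illustrative assumptions, not part of the patch; the comments mirror the call order of the new `apply` implementation above.

```python
import nncf
import torch

# Toy stand-in for a real network; any model with linear/matmul weights works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

# Data-free 4-bit compression. With this refactor, apply() now runs three steps:
#   1. get_weight_compression_parameters()     - builds the WeightCompressionParameters list
#   2. collect_weight_compression_statistics() - a no-op here, since no dataset is given
#   3. apply_wc_algos()                        - mixed-precision assignment, config, transforms
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=0.8,  # ~80% of ratio-defining weights get INT4; the rest keep the backup precision
    group_size=64,
)
```

Splitting statistics collection out of `get_weight_compression_parameters` is what lets callers supply pre-collected `statistic_points` or skip collection entirely: `collect_weight_compression_statistics` returns early when no dataset is provided or no data-aware algorithm is enabled.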