diff --git a/src/nncf/quantization/algorithms/weight_compression/algorithm.py b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
index 1fb2f343f72..d0407a1eff4 100644
--- a/src/nncf/quantization/algorithms/weight_compression/algorithm.py
+++ b/src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -769,15 +769,28 @@ def is_weight_compression_supported(
 
         return is_supported_dtype and not no_bit_reduction
 
-    def apply(
+    def get_weight_compression_parameters(
         self,
         model: TModel,
         graph: NNCFGraph,
         statistic_points: Optional[StatisticPointsContainer] = None,
         dataset: Optional[Dataset] = None,
-    ) -> TModel:
-        self.set_backend_entity(model)
+    ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
+        """
+        Generates a list of weight compression parameters based on the Weight Compression algorithm
+        configuration. Determines the appropriate quantization parameters for each node eligible for
+        weight compression. Also generates a mapping of target node names to the collected statistics
+        based on the provided statistic_points. If statistic_points is None, the required compression
+        statistics are collected on the given dataset.
+        :param model: Backend-specific input model.
+        :param graph: NNCFGraph instance.
+        :param statistic_points: Optional pre-collected statistic points.
+        :param dataset: Optional dataset for statistics collection.
+        :return: A tuple consisting of a list of weight compression parameters, based on the Weight
+            Compression algorithm configuration, and a mapping of target node names to the
+            collected statistics.
+        """
 
         nodes_to_compress = self.get_nodes_to_compress(graph)
 
         all_weight_params: list[WeightCompressionParameters] = []
@@ -787,12 +800,13 @@ def apply(
         is_last_layer_skipped = False
         n = len(nodes_to_compress)
         ignored_names = self.get_ignored_node_names(graph)
+
         for i, node in enumerate(nodes_to_compress):
             is_target_node = should_consider_scope(node.node_name, ignored_names)
             for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
                 is_last_layer = i == n - 1
                 if weight_name in weight_names:
-                    # If the last layer has shared weights then skiped
+                    # If the last layer has shared weights, skip it
                    # to avoid processing the same weight more than once
                     is_last_layer_skipped = is_last_layer
                     continue
@@ -828,6 +842,7 @@ def apply(
                 )
                 if self.is_weight_compression_supported(weight_dtype, mode):
                     wc_config = WeightCompressionConfig(mode=mode)
+
                     weight_params = WeightCompressionParameters(
                         weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
                     )
@@ -884,6 +899,20 @@ def apply(
         # Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
         all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
 
+        return all_weight_params, statistics
+
+    def apply(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TModel:
+        self.set_backend_entity(model)
+
+        # Get processed weight compression parameters ready for compression
+        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
+
         if self._awq:
             model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
             # After applying AWQ we need to update statistics since AWQ alters the activations
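
For readers who want to use the new seam directly, here is a minimal sketch of how the two entry points compose after this refactor. The call order (`set_backend_entity` before `get_weight_compression_parameters`, mirroring the new `apply`) and the names `weight_name` and `compression_config` follow the constructor arguments and filter shown in the diff; the construction of `algo`, `model`, `graph`, and `dataset` is assumed, backend-specific, and not part of this change.

```python
# Minimal sketch, assuming `algo` is an already-constructed WeightCompression
# instance and `model`/`graph`/`dataset` are the matching backend-specific
# objects (their setup is hypothetical and outside this diff).

# The new apply() calls set_backend_entity() before collecting parameters,
# so a direct caller of get_weight_compression_parameters() must do the same.
algo.set_backend_entity(model)

# New seam introduced by this PR: obtain the per-node compression parameters
# and the collected statistics without running the full compression pipeline.
all_weight_params, statistics = algo.get_weight_compression_parameters(
    model, graph, statistic_points=None, dataset=dataset
)

# After the final filter in get_weight_compression_parameters(), every entry
# has a non-None compression_config, so it is safe to inspect directly.
for w_params in all_weight_params:
    print(w_params.weight_name, w_params.compression_config.mode)
```

Splitting parameter generation out of `apply` also lets callers inspect or cache `all_weight_params` before the AWQ step and the actual weight transformation run.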