102 changes: 54 additions & 48 deletions src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -773,8 +773,6 @@ def get_weight_compression_parameters(
        self,
        model: TModel,
        graph: NNCFGraph,
        statistic_points: Optional[StatisticPointsContainer] = None,
        dataset: Optional[Dataset] = None,
    ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
        """
        Generates a list of weight compression parameters based on the Weight Compression algorithm
@@ -869,37 +867,34 @@ def get_weight_compression_parameters(
        else:
            group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}

        # Collect statistics for the weights compression
        statistics = None
        if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
            weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
            matmul_nodes_to_compress = [
                wp.node_with_weight
                for wp in weight_params
                if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
            ]
            matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(
                matmul_nodes_to_compress, graph
            )
            if statistic_points is None:
                statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
            statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)
            statistics = self._get_statistics_for_weights_compression(
                matmul_input_to_output_nodes_map, statistic_points
            )

        # Set weight compression configuration
        self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)

        # Print statistics
        nncf_logger.info(
            self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
        )

        # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
        return all_weight_params, ratio_defining_params, group_size_values

        return all_weight_params, statistics
    def _collect_statistics_and_statistic_points(
        self, model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
    ):
        if not dataset or not (self._data_aware_mixed_precision or self._data_aware_compression):
            return None, statistic_points
        weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
        matmul_nodes_to_compress = [
            wp.node_with_weight
            for wp in weight_params
            if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
        ]
        matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
        if statistic_points is None:
            statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
        statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
        statistics_aggregator.register_statistic_points(statistic_points)
        statistics_aggregator.collect_statistics(model, graph)
        statistic_points = statistics_aggregator.statistic_points
        return self._get_statistics_for_weights_compression(
            matmul_input_to_output_nodes_map, statistic_points
        ), statistic_points

    def apply(
        self,
@@ -911,7 +906,38 @@ def apply(
        self.set_backend_entity(model)

        # Get processed weight compression parameters ready for compression
        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
        all_weight_params, ratio_defining_params, group_size_values = self.get_weight_compression_parameters(
            model, graph
        )
Comment on lines +909 to +911
@daniil-lyakhov (Owner, Author) commented on Oct 2, 2025:

        all_weight_params, ratio_defining_params, group_size_values = self._quantizer.get_weight_compression_parameters(
            model, graph
        )

The quantizer doesn't need sensitivity metrics, ratio, etc. It is the basic mixed precision: if the node is an embedding node, last node, or conv node, it is in the backup precision. We have to keep the backup precision, though.

        return self.apply_with_parameters(
@daniil-lyakhov (Owner, Author) commented on Oct 2, 2025:

        self._algo.apply_with_parameters(
            model,
            graph,
            dataset,
            statistic_points,
            all_weight_params,
            ratio_defining_params,
            group_size_values,
        )

The algo has all the required params, like the sensitivity metric, ratio, etc.

            model,
            graph,
            dataset,
            statistic_points,
            all_weight_params,
            ratio_defining_params,
            group_size_values,
        )
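
To make the intent of the two review comments above concrete, here is a minimal, self-contained sketch of the proposed split; WeightParam, BasicMixedPrecisionQuantizer, and WeightCompressionAlgo are illustrative placeholders rather than NNCF classes, and the group size and bit-width values are arbitrary. The idea: the quantizer only decides, by node role, which weights stay in the backup precision, while the algorithm keeps the data-aware pieces (sensitivity metric, ratio, statistics collection) and assigns the final compression config.

# Editor's sketch (hypothetical, not the NNCF API) of the split proposed above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class WeightParam:
    weight_name: str
    node_role: str  # e.g. "embedding", "last", "conv", "matmul"
    compression_config: Optional[str] = None  # stand-in for WeightCompressionConfig


class BasicMixedPrecisionQuantizer:
    """Needs no sensitivity metric or ratio: backup precision is chosen by node role."""

    BACKUP_ROLES = {"embedding", "last", "conv"}

    def get_weight_compression_parameters(self, all_weight_params):
        # Nodes in BACKUP_ROLES stay in the backup precision; the rest define the ratio.
        ratio_defining_params = [p for p in all_weight_params if p.node_role not in self.BACKUP_ROLES]
        group_size_values = {p.weight_name: 128 for p in ratio_defining_params}
        return all_weight_params, ratio_defining_params, group_size_values


class WeightCompressionAlgo:
    """Owns the sensitivity metric, ratio, and statistics; applies the compression."""

    def apply_with_parameters(self, dataset, all_weight_params, ratio_defining_params, group_size_values):
        # Statistics would be collected here only when a dataset is provided,
        # then the per-weight compression config is assigned.
        for p in all_weight_params:
            p.compression_config = "int4" if p in ratio_defining_params else "int8"
        # Weights left without a config would remain in floating point.
        return [p for p in all_weight_params if p.compression_config is not None]


params = [WeightParam("emb.weight", "embedding"), WeightParam("fc.weight", "matmul")]
quantizer, algo = BasicMixedPrecisionQuantizer(), WeightCompressionAlgo()
all_params, ratio_params, group_sizes = quantizer.get_weight_compression_parameters(params)
compressed = algo.apply_with_parameters(dataset=None, all_weight_params=all_params,
                                        ratio_defining_params=ratio_params,
                                        group_size_values=group_sizes)

Under such a split, the data-free path never touches a statistics aggregator, and all data-aware logic lives behind a single apply-style entry point, which matches how apply_with_parameters is wired in this diff.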

    def apply_with_parameters(
        self,
        model,
        graph,
        dataset,
        statistic_points,
        all_weight_params,
        ratio_defining_params,
        group_size_values,
    ):
        # Collect statistics for the weights compression
        statistics, statistic_points = self._collect_statistics_and_statistic_points(
            model, graph, statistic_points, dataset, ratio_defining_params, all_weight_params
        )
        # Set weight compression configuration
        self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)

        # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
        all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))

        if self._awq:
            model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
@@ -1048,26 +1074,6 @@ def get_compression_nodes_info(
        matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)
        return nodes_to_compress, matmul_input_to_output_nodes_map

    def _collect_statistics(
        self,
        dataset: Dataset,
        graph: NNCFGraph,
        model: TModel,
        statistic_points: StatisticPointsContainer,
    ):
        """
        Creates a statistics aggregator, registers all statistics specified for the algorithm, and then collects them.

        :param dataset: Dataset to collect values.
        :param graph: Model graph.
        :param model: Model for statistics collection.
        :param statistic_points: Statistics points.
        """
        statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
        statistics_aggregator.register_statistic_points(statistic_points)
        statistics_aggregator.collect_statistics(model, graph)
        return statistics_aggregator.statistic_points

    def get_statistic_points(
        self,
        model: TModel,