185 changes: 131 additions & 54 deletions src/nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -102,7 +102,7 @@ def get_weight_compression_configuration(
)

return {
"mode": mode,
"mode": mode if isinstance(mode, nncf.CompressWeightsMode) else nncf.CompressWeightsMode(mode),
"ratio": ratio or 1,
"group_size": group_size,
"all_layers": all_layers or False,
@@ -508,9 +508,9 @@ def _get_primary_config(self, group_size: int) -> WeightCompressionConfig:
def _set_weight_compression_config(
self,
ratio_defining_params: list[WeightCompressionParameters],
primary_precision_weight_params: list[WeightCompressionParameters],
model: TModel,
graph: NNCFGraph,
statistics_points: StatisticPointsContainer,
group_size_values: dict[str, int],
) -> None:
"""
@@ -520,16 +520,8 @@ def _set_weight_compression_config(
backup precisions.
:param model: The model.
:param graph: The model graph associated with the model.
:param statistics_points: Statistics points.
:param group_size_values: A dictionary mapping weight names to their group size values.
"""
if self._ratio < 1 and len(ratio_defining_params) > 0:
primary_precision_weight_params = self._mixed_precision_algo.apply(
model, graph, statistics_points, weight_params=ratio_defining_params
)
else:
primary_precision_weight_params = ratio_defining_params

for weight_param in primary_precision_weight_params:
weight_param.compression_config = self._get_primary_config(group_size_values[weight_param.weight_name])

@@ -769,10 +761,89 @@ def is_weight_compression_supported(

return is_supported_dtype and not no_bit_reduction

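The dtype gate above is what the new skip logic below relies on. A small self-contained illustration of the two conditions it combines; the dtype set and bit widths are assumptions for the sketch, not NNCF's actual tables:

# Illustrative reimplementation of the check's shape, not NNCF's code.
SUPPORTED_SOURCE_DTYPES = {"float32", "float16", "bfloat16"}

def _is_supported(weight_dtype: str, compressed_bits: int, source_bits: int) -> bool:
    is_supported_dtype = weight_dtype in SUPPORTED_SOURCE_DTYPES
    no_bit_reduction = compressed_bits >= source_bits
    return is_supported_dtype and not no_bit_reduction

assert _is_supported("float16", 4, 16)      # fp16 -> int4: compressible
assert not _is_supported("int8", 8, 8)      # unsupported dtype, no bit reduction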
def collect_weight_compression_statistics(
self,
model: TModel,
graph: NNCFGraph,
dataset: Dataset,
weight_params: list[WeightCompressionParameters],
statistic_points: Optional[StatisticPointsContainer] = None,
) -> tuple[Optional[dict[str, Any]], Optional[StatisticPointsContainer]]:
"""
Collects statistics for weight compression if data-aware compression or
mixed-precision is enabled.

:param model: Backend-specific input model.
:param graph: NNCFGraph instance.
:param dataset: Dataset for statistics collection.
:param weight_params: Weight parameters for which to collect statistics.
:param statistic_points: Optional pre-collected statistic points.
:return: A tuple of the collected statistics (or None if statistics collection is not applicable) and the statistic points used to produce them.
"""
if not (self._data_aware_mixed_precision or self._data_aware_compression) or not dataset:
return None, statistic_points

matmul_nodes_to_compress = [
wp.node_with_weight
for wp in weight_params
if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
]
matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph)

if statistic_points is None:
statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)

statistics = self._get_statistics_for_weights_compression(matmul_input_to_output_nodes_map, statistic_points)
return statistics, statistic_points

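A hypothetical caller-side pattern for the new helper; the algo instance and calibration_dataset are illustrative names, not part of this diff. The returned pair lets callers cache statistic points and avoid re-collection on later passes:

# First call collects statistic points over the dataset (data-aware path).
statistics, statistic_points = algo.collect_weight_compression_statistics(
    model, graph, calibration_dataset, all_weight_params, statistic_points=None
)
# A later call with the cached points skips the graph-wide collection step;
# `statistics` comes back as None when no data-aware path is enabled or no
# dataset is given.
statistics, _ = algo.collect_weight_compression_statistics(
    model, graph, calibration_dataset, all_weight_params, statistic_points=statistic_points
)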
def is_last_layer_skipped(
self,
skipped_params: list[WeightCompressionParameters],
nodes_to_compress: list[NNCFNode],
) -> bool:
"""
Returns True if the final node in nodes_to_compress was skipped during compression, i.e. it appears among the skipped parameters.
"""
if not (nodes_to_compress and skipped_params):
return False
last_node = nodes_to_compress[-1]
return any(param.node_with_weight == last_node for param in skipped_params)

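For intuition, a self-contained toy version of the check; the dataclasses below are illustrative stand-ins for NNCFNode and WeightCompressionParameters, not the real types:

from dataclasses import dataclass

@dataclass(frozen=True)
class _Node:
    name: str

@dataclass
class _Param:
    node_with_weight: _Node

def _is_last_layer_skipped(skipped: list, nodes: list) -> bool:
    # Mirrors the method above: True only when the final node to compress
    # is among the skipped parameters.
    if not (nodes and skipped):
        return False
    return any(p.node_with_weight == nodes[-1] for p in skipped)

nodes = [_Node("matmul_0"), _Node("matmul_1")]
assert _is_last_layer_skipped([_Param(_Node("matmul_1"))], nodes)
assert not _is_last_layer_skipped([_Param(_Node("matmul_0"))], nodes)
assert not _is_last_layer_skipped([], nodes)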
def get_skipped_weight_compression_parameters(
self,
model: TModel,
graph: NNCFGraph,
weight_params: list[WeightCompressionParameters],
) -> list[WeightCompressionParameters]:
"""
Returns parameters for the weights that cannot be compressed, i.e. those whose
dtype is not supported for compression in the current mode.

:param model: Backend-specific input model.
:param graph: NNCFGraph instance.
:param weight_params: Weight compression parameters to check.
:return: Parameters of the weights to be kept in their original precision.
"""
skipped_weight_params: list[WeightCompressionParameters] = []
weight_names = set()

for param in weight_params:
node = param.node_with_weight
for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
# The same weight tensor can be shared across several nodes - process it once
if weight_name in weight_names:
continue
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph)

if not self.is_weight_compression_supported(weight_dtype, self._mode):
skipped_weight_params.append(
WeightCompressionParameters(
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, None
)
)
weight_names.add(weight_name)
return skipped_weight_params

def get_weight_compression_parameters(
self,
model: TModel,
graph: NNCFGraph,
nodes_to_compress: list[NNCFNode],
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> list[WeightCompressionParameters]:
@@ -791,8 +862,6 @@ def get_weight_compression_parameters(
Compression algorithm configuration.
"""
nodes_to_compress = self.get_nodes_to_compress(graph)

all_weight_params: list[WeightCompressionParameters] = []
skipped_weight_params: list[WeightCompressionParameters] = []

@@ -849,16 +918,29 @@ def get_weight_compression_parameters(
all_weight_params.append(weight_params)
weight_names.add(weight_name)
else:
is_last_layer_skipped = is_last_layer
skipped_weight_params.append(
WeightCompressionParameters(
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
)
weight_params = WeightCompressionParameters(
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, None
)
all_weight_params.append(weight_params)
return all_weight_params

def apply_wc_algos(
self,
model: TModel,
graph: NNCFGraph,
all_weight_params: list[WeightCompressionParameters],
nodes_to_compress: list[NNCFNode],
statistics: dict[str, Any],
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
"""
Applies the configured weight compression pipeline (mixed precision selection,
group size fallback handling, AWQ and the final transformation) to the model.
"""
skipped_weight_params = self.get_skipped_weight_compression_parameters(model, graph, all_weight_params)
is_last_layer_skipped = self.is_last_layer_skipped(skipped_weight_params, nodes_to_compress)

# Get subset of nodes to define compression ratio
ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_skipped)

# Handle group size fallback modes
if self._group_size_fallback_mode == GroupSizeFallbackMode.IGNORE:
all_weight_params, ratio_defining_params, skipped_weight_params = self._handle_ignore_group_size_fallback(
@@ -868,51 +950,23 @@ def get_weight_compression_parameters(
ratio_defining_params, group_size_values = self._handle_adjust_group_size_fallback(ratio_defining_params)
else:
group_size_values = {w_params.weight_name: self._group_size for w_params in ratio_defining_params}

# Collect statistics for the weights compression
statistics = None
if (self._data_aware_mixed_precision or self._data_aware_compression) and dataset:
weight_params = ratio_defining_params if self._backup_mode == BackupMode.NONE else all_weight_params
matmul_nodes_to_compress = [
wp.node_with_weight
for wp in weight_params
if wp.node_with_weight.metatype in self._backend_entity.matmul_metatypes
]
matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(
matmul_nodes_to_compress, graph
)
if statistic_points is None:
statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
statistic_points = self._collect_statistics(dataset, graph, model, statistic_points)
statistics = self._get_statistics_for_weights_compression(
matmul_input_to_output_nodes_map, statistic_points
)


if self._ratio < 1 and len(ratio_defining_params) > 0:
primary_precision_weight_params = self._mixed_precision_algo.apply(
model, graph, statistic_points, weight_params=ratio_defining_params
)
else:
primary_precision_weight_params = ratio_defining_params
# Set weight compression configuration
self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points, group_size_values)
self._set_weight_compression_config(
ratio_defining_params, primary_precision_weight_params, model, graph, group_size_values
)

# Print statistics
nncf_logger.info(
self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params, skipped_weight_params)
)

# Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))

return all_weight_params, statistics

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
self.set_backend_entity(model)

# Get processed weight compression parameters ready for compression
all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)

if self._awq:
model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
# After applying AWQ we need to update statistics since AWQ alters the activations
Expand Down Expand Up @@ -983,6 +1037,29 @@ def apply(
},
algo_name="weight_compression",
)

return transformed_model

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
self.set_backend_entity(model)
nodes_to_compress = self.get_nodes_to_compress(graph)

# Get processed weight compression parameters ready for compression
all_weight_params = self.get_weight_compression_parameters(
model, graph, nodes_to_compress, statistic_points, dataset
)
statistics, statistic_points = self.collect_weight_compression_statistics(
model, graph, dataset, all_weight_params, statistic_points
)

transformed_model = self.apply_wc_algos(
model, graph, all_weight_params, nodes_to_compress, statistics, statistic_points, dataset
)

return transformed_model

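End to end, the refactored apply is exercised through NNCF's public entry point. A hedged usage sketch; the parameter values are illustrative and the keyword names are assumptions based on nncf.compress_weights' documented API:

import nncf

# Data-free 4-bit compression, with 80% of the ratio-defining weights in INT4;
# passing dataset=... would additionally enable the data-aware statistics path.
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    group_size=128,
)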
def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> tuple[NNCFNode, int]: