@@ -769,15 +769,28 @@ def is_weight_compression_supported(
 
         return is_supported_dtype and not no_bit_reduction
 
-    def apply(
+    def get_weight_compression_parameters(
         self,
         model: TModel,
         graph: NNCFGraph,
         statistic_points: Optional[StatisticPointsContainer] = None,
         dataset: Optional[Dataset] = None,
-    ) -> TModel:
-        self.set_backend_entity(model)
+    ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
"""
Generates a list of weight compression parameters based on the Weight Compression algorithm
configuration. Determines the appropriate quantization parameters for each node eligible for
weight compression. Also, Generates a mapping of target node names to the collected statistics
based on the provided statistic_points. If statistic_points is None, collects required
compression statistics on the given dataset.

:param model: Backend-specific input model.
:param graph: NNCFGraph instance.
:param statistic_points: Optional pre-collected statistic points.
:param dataset: Optional dataset for statistics collection.
:return: A tuple consisting of a list of weight compression parameters, based on the Weight
Compression algorithm configuration, and a mapping of target node names to the
collected statistics.
"""
         nodes_to_compress = self.get_nodes_to_compress(graph)
 
         all_weight_params: list[WeightCompressionParameters] = []
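
Note: this hunk splits parameter selection out of `apply()` into a helper that can be called on its own. A minimal usage sketch, assuming an already-constructed `WeightCompression` instance `algorithm` and backend-specific `model`/`graph`/`calibration_dataset` objects prepared elsewhere (these names are illustrative, not from this PR):

```python
# apply() calls set_backend_entity() itself; when using the helper
# directly, the backend entity must be set up manually first.
algorithm.set_backend_entity(model)

# Compute per-weight compression configs and the statistics mapping
# without modifying the model.
weight_params, statistics = algorithm.get_weight_compression_parameters(
    model, graph, statistic_points=None, dataset=calibration_dataset
)

# Each entry describes one weight selected for compression; entries whose
# compression_config is None have already been filtered out (see the
# filter() call at the end of the helper).
for wp in weight_params:
    print(wp.weight_name, wp.compression_config)
```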
@@ -787,12 +800,13 @@ def apply(
         is_last_layer_skipped = False
         n = len(nodes_to_compress)
         ignored_names = self.get_ignored_node_names(graph)
+
         for i, node in enumerate(nodes_to_compress):
             is_target_node = should_consider_scope(node.node_name, ignored_names)
             for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
                 is_last_layer = i == n - 1
                 if weight_name in weight_names:
-                    # If the last layer has shared weights then skiped
+                    # If the last layer has shared weights then skip it
+                    # to avoid processing the same weight more than once
                     is_last_layer_skipped = is_last_layer
                     continue
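
Note: the skip branch above exists because tied weights (e.g. token embeddings shared with the output projection) surface under several nodes but must be compressed only once; `is_last_layer_skipped` records whether the duplicate was seen on the last layer, which presumably feeds the algorithm's last-layer handling. A simplified, self-contained sketch of the dedup pattern (all names here are hypothetical stand-ins):

```python
nodes_to_compress = ["embed", "block_0", "lm_head"]
# "tok_emb" is tied: it appears under both "embed" and "lm_head".
node_weights = {"embed": ["tok_emb"], "block_0": ["w_qkv"], "lm_head": ["tok_emb"]}

weight_names: set[str] = set()
is_last_layer_skipped = False
n = len(nodes_to_compress)

for i, node in enumerate(nodes_to_compress):
    for weight_name in node_weights[node]:
        if weight_name in weight_names:
            # Second sighting of a shared weight: skip it, but remember
            # whether this happened on the last layer.
            is_last_layer_skipped = i == n - 1
            continue
        weight_names.add(weight_name)

assert is_last_layer_skipped  # "tok_emb" reappears on the last node
```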
@@ -828,6 +842,7 @@
                 )
                 if self.is_weight_compression_supported(weight_dtype, mode):
                     wc_config = WeightCompressionConfig(mode=mode)
+
                     weight_params = WeightCompressionParameters(
                         weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
                     )
@@ -884,6 +899,20 @@
         # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
         all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
 
+        return all_weight_params, statistics
+
+    def apply(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TModel:
+        self.set_backend_entity(model)
+
+        # Get processed weight compression parameters ready for compression
+        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
+
         if self._awq:
             model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
             # After applying AWQ we need to update statistics since AWQ alters the activations
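
Note: the net effect of the refactor is that `apply()` becomes a thin orchestrator over the new helper. A rough sketch of the resulting control flow; only the lines visible in this diff are reliable, and everything past the AWQ branch is elided here:

```python
# Paraphrase of the refactored apply(); the tail of the method is outside
# the visible diff and may differ from the actual file.
def apply(self, model, graph, statistic_points=None, dataset=None):
    self.set_backend_entity(model)
    all_weight_params, statistics = self.get_weight_compression_parameters(
        model, graph, statistic_points, dataset
    )
    if self._awq:
        model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
        # Per the diff's comment: AWQ alters the activations, so statistics
        # collected before it run stale and need to be updated afterwards.
    ...  # remaining compression steps not shown in this diff
    return model
```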