Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,18 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -
subgraph_list.append(self.get_node_by_key(node_key))
output.append(subgraph_list)
return output


def get_node_names_matching_graph_pattern(nncf_graph: NNCFGraph, graph_pattern: GraphPattern) -> set[str]:
"""
Returns the names of nodes in the given NNCFGraph that match the specified graph pattern.

:param nncf_graph: An instance of NNCFGraph to search for matching subgraphs.
:param graph_pattern: A GraphPattern instance used to identify matching subgraphs.
:return: A set of node names from the NNCFGraph that match the given graph pattern.
"""
nncf_node_names = set()
for subgraph in nncf_graph.find_matching_subgraphs(graph_pattern, strict=False):
for nncf_node in subgraph:
nncf_node_names.add(nncf_node.node_name)
return nncf_node_names
21 changes: 2 additions & 19 deletions src/nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.graph import get_node_names_matching_graph_pattern
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.patterns.manager import PatternsManager
Expand Down Expand Up @@ -585,32 +586,14 @@ def _get_ignored_names(
user_ignored_names = get_ignored_node_names_from_ignored_scope(
self._ignored_scope, nncf_graph, strict=self._ignored_scope.validate
)
autogenerated_ignored_names = self._get_ignored_names_by_ignored_patterns(
inference_nncf_graph, ignored_patterns
)
autogenerated_ignored_names = get_node_names_matching_graph_pattern(inference_nncf_graph, ignored_patterns)
autogenerated_ignored_names |= self._backend_entity.get_ignored_names_by_layer_attributes(inference_nncf_graph)
autogenerated_ignored_names |= self._get_ignored_names_by_algorithm(inference_nncf_graph)
ignored_names = {name: IgnoreReason.AUTOGENERATED for name in autogenerated_ignored_names}
# User ignored scope has higher priority
ignored_names.update({name: IgnoreReason.USER_REQUESTED for name in user_ignored_names})
return ignored_names

def _get_ignored_names_by_ignored_patterns(
self, inference_nncf_graph: NNCFGraph, ignored_patterns: GraphPattern
) -> set[str]:
"""
Returns node names matched ignored_patterns.

:param nncf_graph: Inference graph without constant flows.
:param ignored_patterns: Ignored patterns.
:return: IgnoredScope with all node names matched ignored_patterns.
"""
nncf_node_names = set()
for subgraph in inference_nncf_graph.find_matching_subgraphs(ignored_patterns, strict=False):
for nncf_node in subgraph:
nncf_node_names.add(nncf_node.node_name)
return nncf_node_names

def _get_ignored_names_by_algorithm(self, inference_nncf_graph: NNCFGraph) -> set[str]:
"""
Returns node names for ignored_algorithms matched `quantization`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.graph import get_node_names_matching_graph_pattern
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
Expand Down Expand Up @@ -383,6 +384,12 @@ def get_nodes_to_compress(self, nncf_graph: NNCFGraph) -> list[NNCFNode]:
ignored_names = get_ignored_node_names_from_ignored_scope(
self._ignored_scope, nncf_graph, strict=self._ignored_scope.validate
)

autogenerated_ignored_names = get_node_names_matching_graph_pattern(
nncf_graph, self._backend_entity.get_ignored_patterns()
)
ignored_names = ignored_names.union(autogenerated_ignored_names)

for node in nncf_graph.topological_sort():
is_node_with_weights = self._backend_entity.is_node_with_weights(node, nncf_graph)
is_within_scope = should_consider_scope(node.node_name, ignored_names)
Expand Down
10 changes: 10 additions & 0 deletions src/nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
Expand Down Expand Up @@ -257,6 +258,15 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
:return: Backend-specific callable to filter statistic containers according to its statistic point.
"""

@staticmethod
@abstractmethod
def get_ignored_patterns() -> GraphPattern:
"""
Return backend-specific ignored patterns.

:return: backend-specific ignored patterns.
"""


class AWQAlgoBackend(WeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.utils import get_reduction_axes
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
Expand All @@ -43,6 +44,7 @@
from nncf.onnx.graph.onnx_helper import pack_4_bits
from nncf.onnx.graph.onnx_helper import pack_int4_to_uint8
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
from nncf.onnx.quantization.ignored_patterns import create_rope
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
Expand Down Expand Up @@ -459,3 +461,7 @@ def _replace_matmul_with_matmulnbits(
del self.name_to_node_map[original_matmul.name]
# Update the node mapping
self.name_to_node_map[matmul_n_bits.name] = matmul_n_bits

@staticmethod
def get_ignored_patterns() -> GraphPattern:
return create_rope()
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.utils import get_reduction_axes
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
Expand All @@ -40,6 +41,7 @@
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.optimized_functions import clear_ov_model_cache
from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE
from nncf.openvino.quantization.ignored_patterns import create_rope
from nncf.openvino.rt_info import dump_parameters
from nncf.openvino.statistics.collectors import OVMaxVarianceReducer
from nncf.openvino.statistics.collectors import OVMeanAbsMaxReducer
Expand Down Expand Up @@ -367,6 +369,10 @@ def filter_func(point: StatisticPoint) -> bool:

return filter_func

@staticmethod
def get_ignored_patterns() -> GraphPattern:
return create_rope()


class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.ignored_patterns import create_rope
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
Expand Down Expand Up @@ -521,6 +522,10 @@ def transform_model(

return transformed_model

@staticmethod
def get_ignored_patterns() -> GraphPattern:
return create_rope()


class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
Expand Down Expand Up @@ -53,6 +54,7 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.quantization.ignored_patterns import create_rope
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
Expand Down Expand Up @@ -281,6 +283,10 @@ def transform_model(

return transformed_model

@staticmethod
def get_ignored_patterns() -> GraphPattern:
return create_rope()


class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def cast_to(x: TTensor, dtype: TensorDataType) -> TTensor:
def get_matmul_model() -> TModel:
"""Returns a backend model for test_data_based_criterion."""

@abstractmethod
def get_RoPE_model() -> TModel:
"""Returns a backend model for test_rope_weight_compression."""

@pytest.mark.parametrize(
("mode", "ref_act_score", "ref_score"),
(
Expand Down Expand Up @@ -310,6 +314,24 @@ def test_awq_with_ignored_scope(self):
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
assert int4_num_nodes == int4_ref_num_compressed

def test_rope_weight_compression(self):
model = self.get_RoPE_model()
sz = 8
n_samples = 10

dataset = Dataset([self.to_tensor(np.ones([1, i + 1, sz], dtype=np.float32)) for i in range(n_samples)])
compressed_model = compress_weights(
model,
mode=CompressWeightsMode.INT4_SYM,
ratio=1.0,
group_size=-1,
dataset=dataset,
)

int4_ref_num_compressed = 0
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
assert int4_num_nodes == int4_ref_num_compressed

@staticmethod
@abstractmethod
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from tests.openvino.native.models import MatMul
from tests.openvino.native.models import ModelNamedConsts
from tests.openvino.native.models import OVReferenceModel
from tests.openvino.native.models import RoPEModel
from tests.openvino.native.models import SequentialMatmulModel
from tests.openvino.native.models import WeightsModel
from tests.openvino.native.quantization.test_fq_params_calculation import REFERENCE_SCALES_DIR
Expand Down Expand Up @@ -1533,6 +1534,10 @@ class TestOVTemplateWeightCompression(TemplateWeightCompression):
def get_matmul_model() -> ov.Model:
return IdentityMatmul().ov_model

@staticmethod
def get_RoPE_model() -> ov.Model:
return RoPEModel().ov_model

@staticmethod
def get_sequential_matmul_model() -> ov.Model:
return SequentialMatmulModel().ov_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from nncf.torch.quantization.quantize_functions import pack_uint4
from nncf.torch.quantization.quantize_functions import unpack_int4
from nncf.torch.quantization.quantize_functions import unpack_uint4
from tests.cross_fw.test_templates.helpers import RoPEModel
from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
from tests.torch.test_models.synthetic import ShortTransformer
from tests.torch.test_tensor import cast_to
Expand Down Expand Up @@ -449,6 +450,10 @@ class TestPTTemplateWeightCompression(TemplateWeightCompression):
def get_matmul_model() -> torch.nn.Module:
return MatMulModel(255 * torch.eye(3, dtype=torch.float32))

@staticmethod
def get_RoPE_model() -> torch.nn.Module:
return RoPEModel()

@staticmethod
def get_sequential_matmul_model() -> torch.nn.Module:
return SequentialMatmulModel()
Expand Down
8 changes: 8 additions & 0 deletions tests/torch2/fx/test_compress_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
from tests.cross_fw.test_templates.helpers import RoPEModel
from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
from tests.torch.test_models.synthetic import ShortTransformer
from tests.torch.test_tensor import cast_to
Expand Down Expand Up @@ -319,6 +320,13 @@ def get_matmul_model() -> torch.fx.GraphModule:
exported_model = get_torch_fx_model(model, ex_input)
return exported_model

@staticmethod
def get_RoPE_model() -> torch.fx.GraphModule:
model = RoPEModel()
ex_input = torch.ones(RoPEModel.INPUT_SIZE, dtype=torch.float32)
exported_model = get_torch_fx_model(model, ex_input)
return exported_model

@staticmethod
def get_sequential_matmul_model() -> torch.fx.GraphModule:
model = SequentialMatmulModel()
Expand Down