Skip to content

Commit d9fc39a

Browse files
[WeightCompression] Ignored patterns introduction (#3562)
### Changes
Ignored patterns are introduced to the WeightCompression algorithm.

### Reason for changes
The RoPE pattern breaks the logic of the WeightCompression algorithm, so it must be ignored automatically.

### Related tickets
huggingface/optimum-intel#1295, 164548

### Tests
test_rope_weight_compression

---------

Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
1 parent 1bebabe commit d9fc39a

File tree

12 files changed

+97
-19
lines changed

12 files changed

+97
-19
lines changed

src/nncf/common/graph/graph.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,18 @@ def find_matching_subgraphs(self, patterns: GraphPattern, strict: bool = True) -
810810
subgraph_list.append(self.get_node_by_key(node_key))
811811
output.append(subgraph_list)
812812
return output
813+
814+
815+
def get_node_names_matching_graph_pattern(nncf_graph: NNCFGraph, graph_pattern: GraphPattern) -> set[str]:
    """
    Collects the names of all nodes that belong to subgraphs matched by the given graph pattern.

    Matching is delegated to `NNCFGraph.find_matching_subgraphs`, called with `strict=False`.

    :param nncf_graph: NNCFGraph instance to search for matching subgraphs.
    :param graph_pattern: GraphPattern instance describing the subgraphs to look for.
    :return: Set of node names gathered from every matched subgraph.
    """
    matched_subgraphs = nncf_graph.find_matching_subgraphs(graph_pattern, strict=False)
    # A node that appears in several matched subgraphs is reported once.
    return {node.node_name for subgraph in matched_subgraphs for node in subgraph}

src/nncf/quantization/algorithms/min_max/algorithm.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from nncf.common.factory import ModelTransformerFactory
2323
from nncf.common.graph.graph import NNCFGraph
2424
from nncf.common.graph.graph import NNCFNode
25+
from nncf.common.graph.graph import get_node_names_matching_graph_pattern
2526
from nncf.common.graph.operator_metatypes import OperatorMetatype
2627
from nncf.common.graph.patterns import GraphPattern
2728
from nncf.common.graph.patterns.manager import PatternsManager
@@ -585,32 +586,14 @@ def _get_ignored_names(
585586
user_ignored_names = get_ignored_node_names_from_ignored_scope(
586587
self._ignored_scope, nncf_graph, strict=self._ignored_scope.validate
587588
)
588-
autogenerated_ignored_names = self._get_ignored_names_by_ignored_patterns(
589-
inference_nncf_graph, ignored_patterns
590-
)
589+
autogenerated_ignored_names = get_node_names_matching_graph_pattern(inference_nncf_graph, ignored_patterns)
591590
autogenerated_ignored_names |= self._backend_entity.get_ignored_names_by_layer_attributes(inference_nncf_graph)
592591
autogenerated_ignored_names |= self._get_ignored_names_by_algorithm(inference_nncf_graph)
593592
ignored_names = {name: IgnoreReason.AUTOGENERATED for name in autogenerated_ignored_names}
594593
# User ignored scope has higher priority
595594
ignored_names.update({name: IgnoreReason.USER_REQUESTED for name in user_ignored_names})
596595
return ignored_names
597596

598-
def _get_ignored_names_by_ignored_patterns(
599-
self, inference_nncf_graph: NNCFGraph, ignored_patterns: GraphPattern
600-
) -> set[str]:
601-
"""
602-
Returns node names matched ignored_patterns.
603-
604-
:param nncf_graph: Inference graph without constant flows.
605-
:param ignored_patterns: Ignored patterns.
606-
:return: IgnoredScope with all node names matched ignored_patterns.
607-
"""
608-
nncf_node_names = set()
609-
for subgraph in inference_nncf_graph.find_matching_subgraphs(ignored_patterns, strict=False):
610-
for nncf_node in subgraph:
611-
nncf_node_names.add(nncf_node.node_name)
612-
return nncf_node_names
613-
614597
def _get_ignored_names_by_algorithm(self, inference_nncf_graph: NNCFGraph) -> set[str]:
615598
"""
616599
Returns node names for ignored_algorithms matched `quantization`.

src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nncf.common.factory import StatisticsAggregatorFactory
2121
from nncf.common.graph.graph import NNCFGraph
2222
from nncf.common.graph.graph import NNCFNode
23+
from nncf.common.graph.graph import get_node_names_matching_graph_pattern
2324
from nncf.common.graph.transformations.commands import TargetType
2425
from nncf.common.logging import nncf_logger
2526
from nncf.common.logging.track_progress import track
@@ -419,6 +420,12 @@ def get_nodes_to_compress(self, nncf_graph: NNCFGraph) -> list[NNCFNode]:
419420
ignored_names = get_ignored_node_names_from_ignored_scope(
420421
self._ignored_scope, nncf_graph, strict=self._ignored_scope.validate
421422
)
423+
424+
autogenerated_ignored_names = get_node_names_matching_graph_pattern(
425+
nncf_graph, self._backend_entity.get_ignored_patterns()
426+
)
427+
ignored_names = ignored_names.union(autogenerated_ignored_names)
428+
422429
for node in nncf_graph.topological_sort():
423430
is_node_with_weights = self._backend_entity.is_node_with_weights(node, nncf_graph)
424431
is_within_scope = should_consider_scope(node.node_name, ignored_names)

src/nncf/quantization/algorithms/weight_compression/backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from nncf.common.graph import NNCFGraph
1717
from nncf.common.graph import NNCFNode
1818
from nncf.common.graph.operator_metatypes import OperatorMetatype
19+
from nncf.common.graph.patterns.patterns import GraphPattern
1920
from nncf.common.graph.transformations.commands import TargetPoint
2021
from nncf.common.graph.transformations.commands import TargetType
2122
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
@@ -256,6 +257,15 @@ def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) ->
256257
:return: Backend-specific callable to filter statistic containers according to its statistic point.
257258
"""
258259

260+
@staticmethod
@abstractmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the backend-specific graph patterns to be ignored.

    The weight compression algorithm matches these patterns against the graph and adds
    the names of the matched nodes to its ignored names.

    :return: Backend-specific ignored patterns.
    """
268+
259269

260270
class AWQAlgoBackend(WeightCompressionAlgoBackend):
261271
@staticmethod

src/nncf/quantization/algorithms/weight_compression/onnx_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nncf.common.graph import NNCFGraph
2222
from nncf.common.graph import NNCFNode
2323
from nncf.common.graph.operator_metatypes import OperatorMetatype
24+
from nncf.common.graph.patterns.patterns import GraphPattern
2425
from nncf.common.graph.transformations.commands import TargetType
2526
from nncf.common.graph.utils import get_reduction_axes
2627
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
@@ -43,6 +44,7 @@
4344
from nncf.onnx.graph.onnx_helper import pack_4_bits
4445
from nncf.onnx.graph.onnx_helper import pack_int4_to_uint8
4546
from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
47+
from nncf.onnx.quantization.ignored_patterns import create_rope
4648
from nncf.parameters import CompressionFormat
4749
from nncf.parameters import CompressWeightsMode
4850
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
@@ -458,3 +460,7 @@ def _replace_matmul_with_matmulnbits(
458460
del self.name_to_node_map[original_matmul.name]
459461
# Update the node mapping
460462
self.name_to_node_map[matmul_n_bits.name] = matmul_n_bits
463+
464+
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the ignored patterns for the ONNX backend.

    :return: The graph pattern built by `create_rope`.
    """
    return create_rope()

src/nncf/quantization/algorithms/weight_compression/openvino_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from nncf.common.graph import NNCFGraph
1818
from nncf.common.graph import NNCFNode
1919
from nncf.common.graph.operator_metatypes import OperatorMetatype
20+
from nncf.common.graph.patterns.patterns import GraphPattern
2021
from nncf.common.graph.transformations.commands import TargetType
2122
from nncf.common.graph.utils import get_reduction_axes
2223
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
@@ -41,6 +42,7 @@
4142
from nncf.openvino.graph.transformations.commands import OVTargetPoint
4243
from nncf.openvino.optimized_functions import clear_ov_model_cache
4344
from nncf.openvino.optimized_functions.models import OV_MODEL_CACHE
45+
from nncf.openvino.quantization.ignored_patterns import create_rope
4446
from nncf.openvino.rt_info import dump_parameters
4547
from nncf.openvino.statistics.collectors import OVMaxVarianceReducer
4648
from nncf.openvino.statistics.collectors import OVMeanAbsMaxReducer
@@ -367,6 +369,10 @@ def filter_func(point: StatisticPoint) -> bool:
367369

368370
return filter_func
369371

372+
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the ignored patterns for the OpenVINO backend.

    :return: The graph pattern built by `create_rope`.
    """
    return create_rope()
375+
370376

371377
class OVTensorWeightCompressionAlgoBackend(OVWeightCompressionAlgoBackend):
372378
"""

src/nncf/quantization/algorithms/weight_compression/torch_backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from nncf.torch.model_graph_manager import split_const_name
6767
from nncf.torch.model_transformer import PTModelTransformer
6868
from nncf.torch.nncf_network import NNCFNetwork
69+
from nncf.torch.quantization.ignored_patterns import create_rope
6970
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
7071
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
7172
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
@@ -510,6 +511,10 @@ def transform_model(
510511

511512
return transformed_model
512513

514+
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the ignored patterns for the Torch backend.

    :return: The graph pattern built by `create_rope`.
    """
    return create_rope()
517+
513518

514519
class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
515520
@staticmethod

src/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nncf.common.graph.graph import NNCFGraph
2121
from nncf.common.graph.graph import NNCFNode
2222
from nncf.common.graph.operator_metatypes import OperatorMetatype
23+
from nncf.common.graph.patterns.patterns import GraphPattern
2324
from nncf.common.graph.transformations.commands import TargetType
2425
from nncf.common.graph.transformations.layout import TransformationLayout
2526
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
@@ -53,6 +54,7 @@
5354
from nncf.torch.graph.transformations.commands import PTTargetPoint
5455
from nncf.torch.model_graph_manager import get_const_node
5556
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
57+
from nncf.torch.quantization.ignored_patterns import create_rope
5658
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
5759
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
5860
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
@@ -273,6 +275,10 @@ def transform_model(
273275

274276
return transformed_model
275277

278+
@staticmethod
def get_ignored_patterns() -> GraphPattern:
    """
    Returns the ignored patterns for the Torch FX backend.

    :return: The graph pattern built by `create_rope`.
    """
    return create_rope()
281+
276282

277283
class FXMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, FXWeightCompressionAlgoBackend):
278284
@staticmethod

tests/cross_fw/test_templates/template_test_weights_compression.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def cast_to(x: TTensor, dtype: TensorDataType) -> TTensor:
7979
def get_matmul_model() -> TModel:
8080
"""Returns a backend model for test_data_based_criterion."""
8181

82+
@staticmethod
@abstractmethod
def get_RoPE_model() -> TModel:
    """
    Returns a backend model for test_rope_weight_compression.

    Declared as a static method because it takes no arguments and is invoked as
    `self.get_RoPE_model()`; backend test classes override it with a `@staticmethod`.

    :return: Backend-specific model passed to `compress_weights` by the test.
    """
85+
8286
@pytest.mark.parametrize(
8387
("mode", "ref_act_score", "ref_score"),
8488
(
@@ -323,6 +327,24 @@ def test_awq_with_ignored_scope(self):
323327
int4_num_nodes = self.get_num_int4_nodes(compressed_model)
324328
assert int4_num_nodes == int4_ref_num_compressed
325329

330+
def test_rope_weight_compression(self):
    """
    Checks that compressing the RoPE model with INT4_SYM and ratio=1.0 produces no INT4 nodes.
    """
    rope_model = self.get_RoPE_model()
    last_dim = 8
    num_samples = 10

    # All-ones float32 samples of shape [1, k, last_dim] for k = 1..num_samples.
    samples = [
        self.to_tensor(np.ones([1, idx + 1, last_dim], dtype=np.float32)) for idx in range(num_samples)
    ]
    calibration_dataset = Dataset(samples)

    compressed_model = compress_weights(
        rope_model,
        mode=CompressWeightsMode.INT4_SYM,
        ratio=1.0,
        group_size=-1,
        dataset=calibration_dataset,
    )

    # No node is expected to be compressed to INT4.
    assert self.get_num_int4_nodes(compressed_model) == 0
347+
326348
@staticmethod
327349
@abstractmethod
328350
def get_reference_for_test_awq_scale_reference() -> dict[str, Tensor]:

tests/openvino/native/quantization/test_weights_compression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from tests.openvino.native.models import MatMul
7272
from tests.openvino.native.models import ModelNamedConsts
7373
from tests.openvino.native.models import OVReferenceModel
74+
from tests.openvino.native.models import RoPEModel
7475
from tests.openvino.native.models import SequentialMatmulModel
7576
from tests.openvino.native.models import WeightsModel
7677
from tests.openvino.native.quantization.test_fq_params_calculation import REFERENCE_SCALES_DIR
@@ -1754,6 +1755,10 @@ class TestOVTemplateWeightCompression(TemplateWeightCompression):
17541755
def get_matmul_model() -> ov.Model:
17551756
return IdentityMatmul().ov_model
17561757

1758+
@staticmethod
def get_RoPE_model() -> ov.Model:
    """Returns the `ov_model` of `RoPEModel`, used by test_rope_weight_compression."""
    return RoPEModel().ov_model
1761+
17571762
@staticmethod
17581763
def get_sequential_matmul_model() -> ov.Model:
17591764
return SequentialMatmulModel().ov_model

0 commit comments

Comments
 (0)