Skip to content

Commit 4266ed6

Browse files
test_find_groups_of_quantizers_to_rank is presented
1 parent 21dbece commit 4266ed6

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2023 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, List, Optional
13+
14+
from nncf.common.graph.graph import NNCFGraph
15+
from nncf.common.graph.graph import NNCFNode
16+
from nncf.common.graph.operator_metatypes import OperatorMetatype
17+
from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
18+
from nncf.quantization.algorithms.accuracy_control.backend import TModel
19+
from tests.common.quantization.metatypes import CONSTANT_METATYPES
20+
from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES
21+
from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES
22+
from tests.common.quantization.metatypes import QUANTIZER_METATYPES
23+
from tests.common.quantization.metatypes import ShapeOfTestMetatype
24+
25+
26+
class AABackendForTests(AccuracyControlAlgoBackend):
27+
@staticmethod
28+
def get_quantizer_metatypes() -> List[OperatorMetatype]:
29+
return QUANTIZER_METATYPES
30+
31+
@staticmethod
32+
def get_const_metatypes() -> List[OperatorMetatype]:
33+
return CONSTANT_METATYPES
34+
35+
@staticmethod
36+
def get_quantizable_metatypes() -> List[OperatorMetatype]:
37+
return QUANTIZABLE_METATYPES
38+
39+
@staticmethod
40+
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[NNCFNode]:
41+
return nncf_graph.get_input_nodes()
42+
43+
@staticmethod
44+
def get_quantize_agnostic_metatypes() -> List[OperatorMetatype]:
45+
return QUANTIZE_AGNOSTIC_METATYPES
46+
47+
@staticmethod
48+
def get_shapeof_metatypes() -> List[OperatorMetatype]:
49+
return [ShapeOfTestMetatype]
50+
51+
@staticmethod
52+
def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
53+
return False
54+
55+
@staticmethod
56+
def is_node_with_weight(node: NNCFNode) -> bool:
57+
return False
58+
59+
@staticmethod
60+
def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: TModel) -> Any:
61+
return None
62+
63+
@staticmethod
64+
def get_weight_value(node_with_weight: NNCFNode, model: TModel, port_id: int) -> Any:
65+
return None
66+
67+
@staticmethod
68+
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
69+
return None
70+
71+
@staticmethod
72+
def get_model_size(model: TModel) -> int:
73+
return 0
74+
75+
@staticmethod
76+
def prepare_for_inference(model: TModel) -> TModel:
77+
return model

tests/common/accuracy_control/test_ranking.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,20 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
1213
from typing import List
1314

1415
import numpy as np
1516
import pytest
1617

18+
from nncf.common.graph.graph import NNCFGraph
1719
from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse
20+
from nncf.quantization.algorithms.accuracy_control.ranker import GroupToRank
21+
from nncf.quantization.algorithms.accuracy_control.ranker import Ranker
1822
from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices
23+
from tests.common.accuracy_control.backend import AABackendForTests
24+
from tests.common.quantization.test_quantizer_removal import GRAPHS as AA_GRAPHS_DESCR
25+
from tests.common.quantization.test_quantizer_removal import create_nncf_graph as aa_create_nncf_graph
1926

2027

2128
def create_fp32_tensor_1d(items):
@@ -77,3 +84,41 @@ def test_normalized_mse(x_ref: np.ndarray, x_approx: np.ndarray, expected_nmse:
7784
def test_get_subset_indices(errors: List[float], subset_size: int, expected_indices: List[int]):
7885
actual_indices = get_subset_indices(errors, subset_size)
7986
assert expected_indices == actual_indices
87+
88+
89+
@pytest.mark.parametrize(
90+
"nncf_graph_name,ref_groups",
91+
[
92+
(
93+
"simple_graph",
94+
[
95+
GroupToRank(["quantizer_139", "quantizer_162", "quantizer_119"], ["add_117", "conv2d_161"]),
96+
GroupToRank(["quantizer_153", "quantizer_147"], ["conv2d_146"]),
97+
GroupToRank(["quantizer_134", "quantizer_128"], ["conv2d_127"]),
98+
],
99+
),
100+
(
101+
"graph_with_shapeof",
102+
[
103+
GroupToRank(["quantizer_105"], ["interpolate_115"]),
104+
GroupToRank(["quantizer_710", "quantizer_93"], ["multiply_99"]),
105+
GroupToRank(["quantizer_82"], ["power_87"]),
106+
],
107+
),
108+
],
109+
)
110+
def test_find_groups_of_quantizers_to_rank(nncf_graph_name: NNCFGraph, ref_groups: List[GroupToRank]):
111+
ranker = Ranker(1, tuple(), AABackendForTests, None)
112+
nncf_graph = aa_create_nncf_graph(AA_GRAPHS_DESCR[nncf_graph_name])
113+
ret_val = ranker.find_groups_of_quantizers_to_rank(nncf_graph)
114+
assert len(ret_val) == len(ref_groups)
115+
# Can zip as qauantizers are topologically sorted
116+
for actual_group, ref_group in zip(ret_val, ref_groups):
117+
for attr in ["quantizers", "operations"]:
118+
acutal_attr_value = getattr(actual_group, attr)
119+
ref_attr_value = getattr(ref_group, attr)
120+
121+
assert len(acutal_attr_value) == len(ref_attr_value)
122+
actual_node_names = [n.node_name for n in acutal_attr_value]
123+
for ref_node_name in ref_attr_value:
124+
assert ref_node_name in actual_node_names

0 commit comments

Comments
 (0)