|
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
|
| 12 | + |
12 | 13 | from typing import List |
13 | 14 |
|
14 | 15 | import numpy as np |
15 | 16 | import pytest |
16 | 17 |
|
| 18 | +from nncf.common.graph.graph import NNCFGraph |
17 | 19 | from nncf.quantization.algorithms.accuracy_control.rank_functions import normalized_mse |
| 20 | +from nncf.quantization.algorithms.accuracy_control.ranker import GroupToRank |
| 21 | +from nncf.quantization.algorithms.accuracy_control.ranker import Ranker |
18 | 22 | from nncf.quantization.algorithms.accuracy_control.subset_selection import get_subset_indices |
| 23 | +from tests.common.accuracy_control.backend import AABackendForTests |
| 24 | +from tests.common.quantization.test_quantizer_removal import GRAPHS as AA_GRAPHS_DESCR |
| 25 | +from tests.common.quantization.test_quantizer_removal import create_nncf_graph as aa_create_nncf_graph |
19 | 26 |
|
20 | 27 |
|
21 | 28 | def create_fp32_tensor_1d(items): |
@@ -77,3 +84,41 @@ def test_normalized_mse(x_ref: np.ndarray, x_approx: np.ndarray, expected_nmse: |
77 | 84 | def test_get_subset_indices(errors: List[float], subset_size: int, expected_indices: List[int]): |
78 | 85 | actual_indices = get_subset_indices(errors, subset_size) |
79 | 86 | assert expected_indices == actual_indices |
| 87 | + |
| 88 | + |
| 89 | +@pytest.mark.parametrize( |
| 90 | + "nncf_graph_name,ref_groups", |
| 91 | + [ |
| 92 | + ( |
| 93 | + "simple_graph", |
| 94 | + [ |
| 95 | + GroupToRank(["quantizer_139", "quantizer_162", "quantizer_119"], ["add_117", "conv2d_161"]), |
| 96 | + GroupToRank(["quantizer_153", "quantizer_147"], ["conv2d_146"]), |
| 97 | + GroupToRank(["quantizer_134", "quantizer_128"], ["conv2d_127"]), |
| 98 | + ], |
| 99 | + ), |
| 100 | + ( |
| 101 | + "graph_with_shapeof", |
| 102 | + [ |
| 103 | + GroupToRank(["quantizer_105"], ["interpolate_115"]), |
| 104 | + GroupToRank(["quantizer_710", "quantizer_93"], ["multiply_99"]), |
| 105 | + GroupToRank(["quantizer_82"], ["power_87"]), |
| 106 | + ], |
| 107 | + ), |
| 108 | + ], |
| 109 | +) |
| 110 | +def test_find_groups_of_quantizers_to_rank(nncf_graph_name: NNCFGraph, ref_groups: List[GroupToRank]): |
| 111 | + ranker = Ranker(1, tuple(), AABackendForTests, None) |
| 112 | + nncf_graph = aa_create_nncf_graph(AA_GRAPHS_DESCR[nncf_graph_name]) |
| 113 | + ret_val = ranker.find_groups_of_quantizers_to_rank(nncf_graph) |
| 114 | + assert len(ret_val) == len(ref_groups) |
| 115 | + # Can zip as qauantizers are topologically sorted |
| 116 | + for actual_group, ref_group in zip(ret_val, ref_groups): |
| 117 | + for attr in ["quantizers", "operations"]: |
| 118 | + acutal_attr_value = getattr(actual_group, attr) |
| 119 | + ref_attr_value = getattr(ref_group, attr) |
| 120 | + |
| 121 | + assert len(acutal_attr_value) == len(ref_attr_value) |
| 122 | + actual_node_names = [n.node_name for n in acutal_attr_value] |
| 123 | + for ref_node_name in ref_attr_value: |
| 124 | + assert ref_node_name in actual_node_names |
0 commit comments