|
10 | 10 | # limitations under the License. |
11 | 11 | from abc import abstractmethod |
12 | 12 | from collections import Counter |
| 13 | +from copy import deepcopy |
13 | 14 | from typing import Dict |
14 | 15 |
|
15 | 16 | import pytest |
|
27 | 28 | from nncf.quantization.advanced_parameters import OverflowFix |
28 | 29 | from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization |
29 | 30 | from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization |
| 31 | +from nncf.quantization.passes import transform_to_inference_graph |
30 | 32 | from nncf.quantization.range_estimator import RangeEstimatorParametersSet |
31 | 33 | from nncf.scopes import IgnoredScope |
32 | 34 | from tests.common.quantization.metatypes import Conv2dTestMetatype |
@@ -262,13 +264,14 @@ def test_quantization_points_overflow_fix(self, overflow_fix, affected_target_po |
262 | 264 | @pytest.mark.parametrize("validate_scopes", (True, False)) |
263 | 265 | def test_validate_scope(self, test_params, validate_scopes): |
264 | 266 | nncf_graph = test_params["test_model_type_pass"]["nncf_graph"] |
| 267 | + inference_nncf_graph = transform_to_inference_graph(deepcopy(nncf_graph), []) |
265 | 268 | ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"] |
266 | 269 | algo = MinMaxQuantization( |
267 | 270 | ignored_scope=IgnoredScope(names=["some_node"], validate=validate_scopes), |
268 | 271 | ) |
269 | 272 | algo._backend_entity = self.get_algo_backend() |
270 | 273 | if validate_scopes: |
271 | 274 | with pytest.raises(RuntimeError, match="Ignored nodes with name"): |
272 | | - algo._get_ignored_names(nncf_graph, ignored_patterns) |
| 275 | + algo._get_ignored_names(nncf_graph, inference_nncf_graph, ignored_patterns) |
273 | 276 | else: |
274 | | - algo._get_ignored_names(nncf_graph, ignored_patterns) |
| 277 | + algo._get_ignored_names(nncf_graph, inference_nncf_graph, ignored_patterns) |
0 commit comments