Skip to content

Commit 4a244f3

Browse files
Fix common pre-commit
1 parent 81ad932 commit 4a244f3

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

nncf/quantization/algorithms/accuracy_control/ranker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
2929
from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func
3030
from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset
31+
from nncf.quantization.passes import filter_constant_nodes
3132
from nncf.quantization.passes import remove_shapeof_subgraphs
3233

3334
TModel = TypeVar("TModel")
@@ -98,8 +99,11 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) ->
9899
if x.metatype in self._algo_backend.get_quantizer_metatypes()
99100
]
100101

102+
quantized_model_graph_without_shapeof = filter_constant_nodes(
103+
deepcopy(quantized_model_graph), self._algo_backend.get_const_metatypes()
104+
)
101105
quantized_model_graph_without_shapeof = remove_shapeof_subgraphs(
102-
deepcopy(quantized_model_graph), self._algo_backend.get_shapeof_metatypes()
106+
quantized_model_graph_without_shapeof, self._algo_backend.get_shapeof_metatypes()
103107
)
104108

105109
for quantizer_node in reversed(quantizers):

tests/common/quantization/metatypes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ class BatchNormTestMetatype(TestMetatype):
3636
@METATYPES_FOR_TEST.register()
3737
class Conv2dTestMetatype(TestMetatype):
3838
name = "conv2d"
39+
input_edges_num_expected = 2
3940

4041

4142
@METATYPES_FOR_TEST.register()
4243
class MatMulTestMetatype(TestMetatype):
4344
name = "matmul"
45+
input_edges_num_expected = 2
4446

4547

4648
@METATYPES_FOR_TEST.register()
@@ -76,6 +78,7 @@ class CatTestMetatype(TestMetatype):
7678
@METATYPES_FOR_TEST.register()
7779
class LinearTestMetatype(TestMetatype):
7880
name = "linear"
81+
input_edges_num_expected = 2
7982

8083

8184
@METATYPES_FOR_TEST.register()
@@ -96,11 +99,13 @@ class IdentityTestMetatype(TestMetatype):
9699
@METATYPES_FOR_TEST.register()
97100
class ReshapeTestMetatype(TestMetatype):
98101
name = "reshape"
102+
input_edges_num_expected = 2
99103

100104

101105
@METATYPES_FOR_TEST.register()
102106
class QuantizerTestMetatype(TestMetatype):
103107
name = "quantizer"
108+
input_edges_num_expected = 2
104109

105110

106111
@METATYPES_FOR_TEST.register()
@@ -116,6 +121,7 @@ class ReluTestMetatype(TestMetatype):
116121
@METATYPES_FOR_TEST.register()
117122
class AddTestMetatype(TestMetatype):
118123
name = "add"
124+
input_edges_num_expected = 2
119125

120126

121127
@METATYPES_FOR_TEST.register()
@@ -126,16 +132,19 @@ class ShapeOfTestMetatype(TestMetatype):
126132
@METATYPES_FOR_TEST.register()
127133
class PowerTestMetatype(TestMetatype):
128134
name = "power"
135+
input_edges_num_expected = 2
129136

130137

131138
@METATYPES_FOR_TEST.register()
132139
class MultiplyTestMetatype(TestMetatype):
133140
name = "multiply"
141+
input_edges_num_expected = 2
134142

135143

136144
@METATYPES_FOR_TEST.register()
137145
class InterpolateTestMetatype(TestMetatype):
138146
name = "interpolate"
147+
input_edges_num_expected = 3
139148

140149

141150
@METATYPES_FOR_TEST.register()

tests/common/quantization/test_quantizer_removal.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from nncf.common.graph import NNCFGraph
1919
from nncf.common.graph.layer_attributes import Dtype
2020
from nncf.common.quantization.quantizer_removal import find_quantizer_nodes_to_cut
21+
from nncf.quantization.passes import filter_constant_nodes
2122
from nncf.quantization.passes import remove_shapeof_subgraphs
2223
from tests.common.quantization.metatypes import CONSTANT_METATYPES
2324
from tests.common.quantization.metatypes import METATYPES_FOR_TEST
@@ -226,7 +227,8 @@ def create_test_params():
226227
@pytest.mark.parametrize("nncf_graph,test_case", create_test_params())
227228
def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: TestCase):
228229
quantizer_node = nncf_graph.get_node_by_name(test_case.node_name)
229-
nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), SHAPEOF_METATYPES)
230+
nncf_graph_without_shapeof = filter_constant_nodes(deepcopy(nncf_graph), CONSTANT_METATYPES)
231+
nncf_graph_without_shapeof = remove_shapeof_subgraphs(nncf_graph_without_shapeof, SHAPEOF_METATYPES)
230232
nodes, ops = find_quantizer_nodes_to_cut(
231233
nncf_graph_without_shapeof,
232234
quantizer_node,

0 commit comments

Comments
 (0)