Skip to content

Commit 02641fd

Browse files
irenabirenab
authored andcommitted
fixes after rebase
1 parent 7d4598e commit 02641fd

5 files changed

Lines changed: 5 additions & 23 deletions

File tree

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,7 @@ def override_fused_node_activation_quantization_candidates(self):
873873
fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
874874
if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
875875
def update(qc):
876-
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
877-
fusing_op_quantization_cfg,
878-
qc.activation_quantization_cfg.activation_quantization_fn,
879-
qc.activation_quantization_cfg.activation_quantization_params_fn
880-
)
876+
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
881877
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
882878
node.quantization_cfg.update_all(update)
883879
node.quantization_cfg.remove_duplicates()

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
from mct_quantizers import QuantizationMethod
2121
from model_compression_toolkit.core.common import Graph
2222
from model_compression_toolkit.logger import Logger
23-
24-
2523
from model_compression_toolkit.core.common.graph.base_node import BaseNode
26-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
27-
get_weights_quantization_fn
24+
2825

2926
_EditRule = namedtuple('EditRule', 'filter action')
3027

model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _collect_and_assign_act_threshold(graph: Graph,
5050
for _data in tqdm(representative_data_gen()):
5151
mi.infer(_data)
5252

53-
for n in list(graph.nodes):
53+
for n in graph.nodes:
5454
if n.is_activation_quantization_enabled():
5555
activation_params = compute_activation_qparams(activation_quant_cfg=n.final_activation_quantization_cfg,
5656
node_prior_info=n.prior_info,

model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from model_compression_toolkit.core import QuantizationConfig
1919
from model_compression_toolkit.core.common import BaseNode, Graph
2020
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
2221
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_activation_quantization_fn
2322

2423

tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,18 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import itertools
16-
from copy import deepcopy
1716

1817
import pytest
19-
from unittest.mock import Mock, PropertyMock
18+
from unittest.mock import Mock
2019

2120
from mct_quantizers import QuantizationMethod
2221
from model_compression_toolkit.core.common import Graph
23-
from model_compression_toolkit.core.common.graph.base_node import BaseNode
2422
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
2523
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
26-
from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc
2724
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode, NodeActivationQuantizationConfig
2825
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
2926
CandidateNodeQuantizationConfig, NodeQuantizationConfig
30-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_selection_histogram
31-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_selection_histogram
32-
from model_compression_toolkit.core import QuantizationErrorMethod
33-
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc
27+
from tests_pytest._test_util.graph_builder_utils import build_node
3428

3529

3630
def build_mock_fusing_info(nodes, idx):
@@ -42,7 +36,6 @@ def build_mock_fusing_info(nodes, idx):
4236
OpQCfg.activation_n_bits = 16
4337
OpQCfg.signedness = Signedness.AUTO
4438
OpQCfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
45-
OpQCfg.activation_quantization_params_fn = power_of_two_selection_histogram
4639
OpQCfg.quantization_preserving = False
4740

4841
fusing_info = Mock(spec=FusingInfo)
@@ -70,8 +63,6 @@ def eq(self_, other):
7063
a_cfgs = [Mock(spec=NodeActivationQuantizationConfig,
7164
quant_mode=Mock(),
7265
activation_n_bits=b,
73-
activation_quantization_fn=symmetric_selection_histogram,
74-
activation_quantization_params_fn=power_of_two_selection_histogram,
7566
__eq__=eq) for b in [5, 6]]
7667

7768
qcs = [CandidateNodeQuantizationConfig(a_cfg, w_cfg) for a_cfg, w_cfg in itertools.product(a_cfgs, w_cfgs)]
@@ -124,7 +115,6 @@ def test_override_fused_node_activation_quantization_candidates(self, idx, patch
124115
assert qc.activation_quantization_cfg.activation_n_bits == 16
125116
assert qc.activation_quantization_cfg.signedness == Signedness.AUTO
126117
assert qc.activation_quantization_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
127-
assert qc.activation_quantization_cfg.activation_quantization_params_fn == power_of_two_selection_histogram
128118
assert qc.weights_quantization_cfg == w_cfgs[i]
129119
base_cfg0 = nodes[0].quantization_cfg.base_quantization_cfg
130120
assert base_cfg0.activation_quantization_cfg.activation_n_bits == 16

0 commit comments

Comments
 (0)