Skip to content

Commit 14516ce

Browse files
authored
remove duplicate candidates from node added by SNC (#1478)
1 parent 06d4589 commit 14516ce

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,7 @@ def override_fused_node_activation_quantization_candidates(self):
875875
def update(qc):
876876
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
877877
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
878-
node.quantization_cfg.update_all(update)
879-
node.quantization_cfg.remove_duplicates()
878+
node.quantization_cfg.update_all(update, remove_duplicates=True)
880879
else:
881880
node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
882881
# Remove duplicate candidates. We cannot compare whole candidates since activation configs might not

model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,20 @@ class NodeQuantizationConfig:
4646

4747
validate: InitVar[bool] = True
4848

49-
def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None]):
49+
def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True):
5050
"""
5151
Apply update function on the base config and all candidates configs.
5252
5353
Args:
5454
update_fn: function to apply.
55+
remove_duplicates: remove duplicate candidates.
5556
"""
5657
if self.base_quantization_cfg:
5758
update_fn(self.base_quantization_cfg)
5859
for cfg in self.candidates_quantization_cfg:
5960
update_fn(cfg)
61+
if remove_duplicates:
62+
self.remove_duplicates()
6063

6164
def update_activation_quantization_mode(self, mode: ActivationQuantizationMode):
6265
"""

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def update(c):
463463
c.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
464464
SIGNED: False})
465465

466-
add_node.quantization_cfg.update_all(update)
466+
add_node.quantization_cfg.update_all(update, remove_duplicates=True)
467467

468468
# Add the new padding node to a fused op with the op2d.
469469
if pad_node:

0 commit comments

Comments
 (0)