Skip to content

Commit 34fec87

Browse files
committed
modify for a PR review comment (SonySemiconductorSolutions#1466 (comment))
1 parent b5f409b commit 34fec87

4 files changed

Lines changed: 13 additions & 8 deletions

File tree

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
15+
from __future__ import annotations
1616

1717
from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
1818
from enum import Enum, auto
@@ -31,6 +31,9 @@
3131
AttributeQuantizationConfig, \
3232
OpQuantizationConfig
3333

34+
if TYPE_CHECKING:
35+
from model_compression_toolkit.core.common import BaseNode
36+
3437
if TYPE_CHECKING:
3538
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
3639

@@ -199,15 +202,17 @@ def set_activation_quantization_params_fn(self, activation_quantization_params_f
199202
self.activation_quantization_params_fn = activation_quantization_params_fn
200203

201204
def set_activation_quantization_param(self,
202-
activation_params: dict):
205+
activation_params: dict,
206+
node: BaseNode):
203207
"""
204208
Set a quantization parameter for the node's activation.
205209
206210
Args:
207-
activation_params: Dictionary that contains weight quantization params.
211+
activation_params: Dictionary that contains activation quantization params.
212+
node: node in a graph that represents the model.
208213
209214
"""
210-
assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
215+
assert node.is_activation_quantization_enabled() or node.is_fln_quantization()
211216
for param_name, param_value in activation_params.items():
212217
self.activation_quantization_params[param_name] = param_value
213218

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ def calculate_quantization_params(graph: Graph,
136136
nodes_prior_info=n.prior_info,
137137
out_stats_container=graph.get_out_stats_collector(n))
138138
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
139-
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
139+
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n)

model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _collect_and_assign_act_threshold(graph: Graph,
5656
activation_quant_cfg=n.final_activation_quantization_cfg,
5757
nodes_prior_info=n.prior_info,
5858
out_stats_container=graph.get_out_stats_collector(n))
59-
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
59+
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params, n)
6060

6161

6262
def quantized_model_builder_for_second_moment_correction(graph: common.Graph,

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def shift_negative_function(graph: Graph,
465465
add_node_qco[op_qc_idx])
466466

467467
candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
468-
SIGNED: False})
468+
SIGNED: False}, add_node)
469469

470470
candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
471471

@@ -482,7 +482,7 @@ def shift_negative_function(graph: Graph,
482482

483483
assert activation_param.get(SIGNED) == False
484484
for candidate_qc in non_linear_node.candidates_quantization_cfg:
485-
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)
485+
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param, non_linear_node)
486486

487487
return graph
488488

0 commit comments

Comments (0)