Skip to content

Commit 1f477f8

Browse files
Apply activation quantization parameters selection (2nd PR internal review) (#15)
Fixes for PR #1466 review comments.
1 parent 5e397a2 commit 1f477f8

5 files changed

Lines changed: 21 additions & 15 deletions

File tree

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
15+
from __future__ import annotations
1616

1717
from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
1818
from enum import Enum, auto
@@ -32,6 +32,7 @@
3232
OpQuantizationConfig
3333

3434
if TYPE_CHECKING:
35+
from model_compression_toolkit.core.common import BaseNode
3536
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
3637

3738
##########################################
@@ -199,15 +200,17 @@ def set_activation_quantization_params_fn(self, activation_quantization_params_f
199200
self.activation_quantization_params_fn = activation_quantization_params_fn
200201

201202
def set_activation_quantization_param(self,
202-
activation_params: dict):
203+
activation_params: dict,
204+
node: BaseNode):
203205
"""
204206
Set a quantization parameter for the node's activation.
205207
206208
Args:
207-
activation_params: Dictionary that contains weight quantization params.
209+
activation_params: Dictionary that contains activation quantization params.
210+
node: node in a graph that represents the model.
208211
209212
"""
210-
assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
213+
assert node.is_activation_quantization_enabled() or node.is_fln_quantization()
211214
for param_name, param_value in activation_params.items():
212215
self.activation_quantization_params[param_name] = param_value
213216

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ def calculate_quantization_params(graph: Graph,
136136
nodes_prior_info=n.prior_info,
137137
out_stats_container=graph.get_out_stats_collector(n))
138138
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
139-
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
139+
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n)

model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _collect_and_assign_act_threshold(graph: Graph,
5656
activation_quant_cfg=n.final_activation_quantization_cfg,
5757
nodes_prior_info=n.prior_info,
5858
out_stats_container=graph.get_out_stats_collector(n))
59-
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params)
59+
n.final_activation_quantization_cfg.set_activation_quantization_param(activation_params, n)
6060

6161

6262
def quantized_model_builder_for_second_moment_correction(graph: common.Graph,

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def shift_negative_function(graph: Graph,
465465
add_node_qco[op_qc_idx])
466466

467467
candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
468-
SIGNED: False})
468+
SIGNED: False}, add_node)
469469

470470
candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
471471

@@ -482,7 +482,7 @@ def shift_negative_function(graph: Graph,
482482

483483
assert activation_param.get(SIGNED) == False
484484
for candidate_qc in non_linear_node.candidates_quantization_cfg:
485-
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param)
485+
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_param, non_linear_node)
486486

487487
return graph
488488

tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,22 @@
2323
NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig
2424
from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig
2525
from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod
26+
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
2627
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
2728
AttachTpcToPytorch
2829
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
2930
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \
3031
AttributeQuantizationConfig
31-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
32+
from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
33+
from model_compression_toolkit.core.common.framework_info import set_fw_info, get_fw_info
34+
3235
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
3336
from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector
3437
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS
3538
from mct_quantizers import QuantizationMethod
3639

40+
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
41+
3742
class TestCalculateQuantizationParams:
3843
def get_op_qco(self):
3944
# define a default quantization config for all non-specified weights attributes.
@@ -145,12 +150,11 @@ def _create_node_weights_op_cfg(self,
145150

146151
def get_test_graph(self, qem: QuantizationErrorMethod):
147152
float_model = self.get_float_model()
148-
fw_info = DEFAULT_PYTORCH_INFO
153+
set_fw_info(PyTorchInfo)
149154

150155
fw_impl = PytorchImplementation()
151156
graph = fw_impl.model_reader(float_model,
152157
self.representative_data_gen)
153-
graph.set_fw_info(fw_info)
154158

155159
quantization_config = QuantizationConfig(weights_error_method=qem)
156160

@@ -165,24 +169,23 @@ def get_test_graph(self, qem: QuantizationErrorMethod):
165169

166170
graph.node_to_out_stats_collector = dict()
167171
for id, n in enumerate(graph.nodes):
168-
n.prior_info = fw_impl.get_node_prior_info(node=n, fw_info=fw_info, graph=graph)
172+
n.prior_info = fw_impl.get_node_prior_info(node=n, graph=graph)
169173
n.candidates_quantization_cfg = []
170174
candidate_qc_a = CandidateNodeQuantizationConfig(
171175
activation_quantization_cfg=NodeActivationQuantizationConfig(qc=quantization_config, op_cfg=op_cfg,
172176
activation_quantization_fn=None,
173177
activation_quantization_params_fn=None),
174178
weights_quantization_cfg=NodeWeightsQuantizationConfig(qc=quantization_config, op_cfg=op_cfg,
175-
weights_channels_axis=(0, 1),
179+
weights_channels_axis=ChannelAxisMapping(0, 1),
176180
node_attrs_list=['weight', 'bias'])
177181
)
178182
if n.name in ['conv3']:
179183
candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
180-
candidate_qc_a.activation_quantization_cfg.activation_n_bits = 16 # set 16bit for FLN node for test.
181184
else:
182185
candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.QUANT
183186
n.candidates_quantization_cfg.append(candidate_qc_a)
184187

185-
graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=fw_info.out_channel_axis_mapping.get(n.type))
188+
graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=get_fw_info().out_channel_axis_mapping.get(n.type))
186189
graph.node_to_out_stats_collector[n].hc._n_bins = 3
187190
if n.name in ['conv1']:
188191
graph.node_to_out_stats_collector[n].hc._bins = np.array([0.4, 0.8, 1.2])

0 commit comments

Comments
 (0)