Skip to content

Commit 523509c

Browse files
Modified for mixed precision
1 parent 0f94ce7 commit 523509c

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import torch
1919
from mct_quantizers import PytorchQuantizationWrapper, QuantizationTarget, \
20-
PytorchActivationQuantizationHolder
20+
PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder
2121
from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
2222
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
2323
from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
@@ -147,7 +147,7 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) ->
147147
'max_candidate_idx': max_candidate_idx
148148
}
149149

150-
def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQuantizationHolder:
150+
def mixed_precision_activation_holder(self, n: BaseNode, prev_n: BaseNode) -> PytorchActivationQuantizationHolder:
151151
"""
152152
Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization for a node.
153153
The layer should hold either a configurable activation quantizer, if it is quantized with mixed precision,
@@ -168,6 +168,7 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
168168
if n.name in activation_conf_nodes_names:
169169
assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None"
170170
node_q_cfg_candidates = n.candidates_quantization_cfg
171+
prev_node_q_cfg_candidates = prev_n.candidates_quantization_cfg
171172

172173
# sorting the candidates by kernel attribute weights number of bits first and then by
173174
# activation number of bits (in reversed order).
@@ -181,11 +182,16 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
181182
max_candidate_idx = max_cfg_candidates[0]
182183

183184
kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
185+
prev_activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': prev_node_q_cfg_candidates,
186+
'max_candidate_idx': max_candidate_idx,
187+
'kernel_attr': kernel_attr})] \
188+
* num_of_outputs
184189
activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
185190
'max_candidate_idx': max_candidate_idx,
186191
'kernel_attr': kernel_attr})] \
187192
* num_of_outputs
188193
else:
194+
prev_node_act_qc = prev_n.get_unique_activation_candidates()
189195
node_act_qc = n.get_unique_activation_candidates()
190196
assert len(node_act_qc) == 1, f"Expected a single activation configuration for node '{n.name}', but found multiple ({len(node_act_qc)}) configurations."
191197
quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
@@ -195,10 +201,19 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
195201

196202
activation_quantizers = [quantizer_for_node(**kwargs)] * num_of_outputs
197203

204+
quantizer_for_prev_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
205+
prev_node_act_qc[0].activation_quantization_cfg.activation_quantization_method,
206+
BasePyTorchInferableQuantizer)
207+
prev_n_kwargs = get_activation_inferable_quantizer_kwargs(prev_node_act_qc[0].activation_quantization_cfg)
208+
prev_activation_quantizers = [quantizer_for_prev_node(**prev_n_kwargs)] * num_of_outputs
209+
198210
# Holder by definition uses a single quantizer for the activation quantization
199211
# thus we make sure this is the only possible case (unless it's a node with no activation
200212
# quantization, which in this case has an empty list).
201213
if len(activation_quantizers) == 1:
214+
if n.is_quantization_preserving():
215+
return PytorchPreservingActivationQuantizationHolder(prev_activation_quantizers[0], quantization_bypass=True)
216+
202217
return PytorchActivationQuantizationHolder(activation_quantizers[0])
203218

204219
Logger.critical(f"PytorchActivationQuantizationHolder expects a single quantizer, but ({len(activation_quantizers)}) quantizers were found for node {n}.")# pragma: no cover

0 commit comments

Comments
 (0)