@@ -17,7 +17,7 @@
 
 import torch
 from mct_quantizers import PytorchQuantizationWrapper, QuantizationTarget, \
-    PytorchActivationQuantizationHolder
+    PytorchActivationQuantizationHolder, PytorchPreservingActivationQuantizationHolder
 from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
 from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
 from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
@@ -147,7 +147,7 @@ def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) ->
                 'max_candidate_idx': max_candidate_idx
                 }
 
-    def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQuantizationHolder:
+    def mixed_precision_activation_holder(self, n: BaseNode, prev_n: BaseNode) -> PytorchActivationQuantizationHolder:
         """
         Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization for a node.
         The layer should hold either a configurable activation quantizer, if it is quantized with mixed precision,
@@ -168,6 +168,7 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
         if n.name in activation_conf_nodes_names:
             assert n.candidates_quantization_cfg is not None, f"Node {n.name} candidates_quantization_cfg is None"
             node_q_cfg_candidates = n.candidates_quantization_cfg
+            prev_node_q_cfg_candidates = prev_n.candidates_quantization_cfg
 
             # sorting the candidates by kernel attribute weights number of bits first and then by
             # activation number of bits (in reversed order).
@@ -181,11 +182,16 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
             max_candidate_idx = max_cfg_candidates[0]
 
             kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
+            prev_activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': prev_node_q_cfg_candidates,
+                                                                             'max_candidate_idx': max_candidate_idx,
+                                                                             'kernel_attr': kernel_attr})] \
+                                         * num_of_outputs
             activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
                                                                         'max_candidate_idx': max_candidate_idx,
                                                                         'kernel_attr': kernel_attr})] \
                                     * num_of_outputs
         else:
+            prev_node_act_qc = prev_n.get_unique_activation_candidates()
             node_act_qc = n.get_unique_activation_candidates()
             assert len(node_act_qc) == 1, f"Expected a single activation configuration for node '{n.name}', but found multiple ({len(node_act_qc)}) configurations."
             quantizer_for_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
@@ -195,10 +201,19 @@ def mixed_precision_activation_holder(self, n: BaseNode) -> PytorchActivationQua
 
             activation_quantizers = [quantizer_for_node(**kwargs)] * num_of_outputs
 
+            quantizer_for_prev_node = get_inferable_quantizer_class(QuantizationTarget.Activation,
+                                                                    prev_node_act_qc[0].activation_quantization_cfg.activation_quantization_method,
+                                                                    BasePyTorchInferableQuantizer)
+            prev_n_kwargs = get_activation_inferable_quantizer_kwargs(prev_node_act_qc[0].activation_quantization_cfg)
+            prev_activation_quantizers = [quantizer_for_prev_node(**prev_n_kwargs)] * num_of_outputs
+
         # Holder by definition uses a single quantizer for the activation quantization
         # thus we make sure this is the only possible case (unless it's a node with no activation
         # quantization, which in this case has an empty list).
         if len(activation_quantizers) == 1:
+            if n.is_quantization_preserving():
+                return PytorchPreservingActivationQuantizationHolder(prev_activation_quantizers[0], quantization_bypass=True)
+
             return PytorchActivationQuantizationHolder(activation_quantizers[0])
 
         Logger.critical(f"PytorchActivationQuantizationHolder expects a single quantizer, but ({len(activation_quantizers)}) quantizers were found for node {n}.")  # pragma: no cover
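
For context, the sketch below shows how the two holder types used in this change might be exercised. It is not part of the diff: the quantizer setup (an illustrative 8-bit power-of-two ActivationPOTInferableQuantizer) and the identity-like behavior expected from quantization_bypass=True are assumptions for illustration; only the PytorchPreservingActivationQuantizationHolder(quantizer, quantization_bypass=True) construction is taken from the code above.

import torch
from mct_quantizers import PytorchActivationQuantizationHolder, \
    PytorchPreservingActivationQuantizationHolder
from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer

# Illustrative inferable activation quantizer (8-bit, power-of-two threshold).
quantizer = ActivationPOTInferableQuantizer(num_bits=8, threshold=[8.0], signed=True)

# Regular holder: its forward fake-quantizes the node's own activation.
holder = PytorchActivationQuantizationHolder(quantizer)

# Preserving holder, as constructed in this change for quantization-preserving nodes:
# it records the predecessor's quantizer but, with quantization_bypass=True, is
# expected to pass the activation through unchanged (assumption based on the name).
preserving_holder = PytorchPreservingActivationQuantizationHolder(quantizer, quantization_bypass=True)

x = torch.randn(1, 3, 8, 8)
y_quantized = holder(x)             # fake-quantized output
y_preserved = preserving_holder(x)  # expected to equal x if the bypass is honored

In other words, the intent of the new prev_n argument is that a node flagged as quantization preserving (presumably ops that keep their input's quantization, such as reshape-like layers) gets a holder built from its predecessor's activation quantization config and bypasses re-quantization, while all other nodes keep the existing PytorchActivationQuantizationHolder path.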