@@ -335,13 +335,35 @@ def compute_activations_utilization(self,
335335 """
336336 return self .compute_activation_utilization_by_cut (target_criterion , bitwidth_mode , act_qcs )
337337
def _extract_qc(self, n: BaseNode, act_qcs: Optional[ActivationQCfgPerNode] = None
                ) -> NodeActivationQuantizationConfig | None:
    """
    Extract the quantization config for a node, if the activation configs dictionary is provided.
    If the node is quantization preserving, extract the quantization config from the preceding
    activation quantized node (i.e. the quantization the original node preserves).

    Args:
        n: Node to extract the quantization config for.
        act_qcs: Custom activations quantization configuration. If not provided, the default
            configuration will be extracted from the node.

    Returns:
        The relevant quantization config, or None if `act_qcs` was not provided (or has no
        entry for the relevant node).
    """
    if act_qcs:
        # Internal invariant: a quantization-preserving node carries no qc of its own,
        # so an entry under its name in act_qcs indicates a caller error.
        assert not n.is_quantization_preserving() or n.name not in act_qcs, \
            f"Quantization preserving node {n.name} should not have a qc for this computation."
        # For quantization-preserving nodes this resolves to the preceding activation-quantized
        # node; for regular nodes it resolves to the node itself.
        return act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name)
    return None
358+
338359 def compute_activation_utilization_by_cut (self ,
339360 target_criterion : TargetInclusionCriterion ,
340361 bitwidth_mode : BitwidthMode ,
341362 act_qcs : Optional [ActivationQCfgPerNode ] = None ) \
342363 -> Tuple [float , Dict [Cut , Utilization ], Dict [Cut , Dict [BaseNode , Utilization ]]]:
343364 """
344- Compute graph activation cuts utilization.
365+ Compute graph activation cuts utilization. If activation quantization configs are provided, then for
366+ quantization preserving nodes, get the previous quantized activation node bit-width.
345367
346368 Args:
347369 target_criterion: criterion to include weights for computation.
@@ -369,7 +391,7 @@ def compute_activation_utilization_by_cut(self,
369391 if not cut_target_nodes :
370392 continue
371393 for n in cut_target_nodes :
372- qc = act_qcs . get ( self .graph . retrieve_preserved_quantization_node ( n ). name ) if act_qcs else None
394+ qc = self ._extract_qc ( n , act_qcs )
373395 util_per_cut_per_node [cut ][n .name ] = self .compute_node_activation_tensor_utilization (n , target_criterion ,
374396 bitwidth_mode , qc )
375397 util_per_cut [cut ] = sum (util_per_cut_per_node [cut ].values ()) # type: ignore
@@ -384,7 +406,8 @@ def compute_activation_tensors_utilization(self,
384406 include_reused = False ) \
385407 -> Tuple [float , Dict [NodeName , Utilization ]]:
386408 """
387- Compute resource utilization for graph's activations tensors.
409+ Compute resource utilization for graph's activations tensors. If activation quantization configs are provided, then for
410+ quantization preserving nodes, get the previous quantized activation node bit-width.
388411
389412 Args:
390413 target_criterion: criterion to include weights for computation.
@@ -405,7 +428,7 @@ def compute_activation_tensors_utilization(self,
405428
406429 util_per_node : Dict [NodeName , Utilization ] = {}
407430 for n in self ._topo_sort (nodes ):
408- qc = act_qcs . get ( n . name ) if act_qcs else None
431+ qc = self . _extract_qc ( n , act_qcs )
409432 util = self .compute_node_activation_tensor_utilization (n , None , bitwidth_mode , qc )
410433 util_per_node [n .name ] = util
411434
@@ -440,7 +463,7 @@ def compute_node_activation_tensor_utilization(self,
440463 return Utilization (0 , 0 )
441464
442465 size = self ._act_tensors_size [n .name ]
443- nbits = self ._get_activation_nbits (self . graph . retrieve_preserved_quantization_node ( n ) , bitwidth_mode , qc )
466+ nbits = self ._get_activation_nbits (n , bitwidth_mode , qc )
444467 bytes_ = size * nbits / 8
445468 return Utilization (size , bytes_ )
446469
0 commit comments