Skip to content

Commit 9a2e4ab

Browse files
committed
Fix PR comments
1 parent 354bb5a commit 9a2e4ab

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,35 @@ def compute_activations_utilization(self,
335335
"""
336336
return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
337337

338+
def _extract_qc(self, n: BaseNode, act_qcs: Optional[ActivationQCfgPerNode] = None
339+
) -> NodeActivationQuantizationConfig | None:
340+
"""
341+
Extract quantization config the activation configs dictionary is provided. If node is quantization
342+
preserving, extract the quantization config from the preceding activation quantized node (i.e.
343+
the Quantization the original node preserves).
344+
345+
Args:
346+
n: Node to extract qc for.
347+
act_qcs: custom activations quantization configuration. If not provided, the default
348+
configuration will be extracted from the node.
349+
350+
Returns:
351+
The relevant quantization config.
352+
"""
353+
if act_qcs:
354+
assert not n.is_quantization_preserving() or n.name not in act_qcs, \
355+
f"Quantization preserving node {n.name} should not have a qc for this computation."
356+
return act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name)
357+
return None
358+
338359
def compute_activation_utilization_by_cut(self,
339360
target_criterion: TargetInclusionCriterion,
340361
bitwidth_mode: BitwidthMode,
341362
act_qcs: Optional[ActivationQCfgPerNode] = None) \
342363
-> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
343364
"""
344-
Compute graph activation cuts utilization.
365+
Compute graph activation cuts utilization. If activation quantization configs are provided, then for
366+
quantization preserving nodes, get the previous quantized activation node bit-width.
345367
346368
Args:
347369
target_criterion: criterion to include weights for computation.
@@ -369,7 +391,7 @@ def compute_activation_utilization_by_cut(self,
369391
if not cut_target_nodes:
370392
continue
371393
for n in cut_target_nodes:
372-
qc = act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name) if act_qcs else None
394+
qc = self._extract_qc(n, act_qcs)
373395
util_per_cut_per_node[cut][n.name] = self.compute_node_activation_tensor_utilization(n, target_criterion,
374396
bitwidth_mode, qc)
375397
util_per_cut[cut] = sum(util_per_cut_per_node[cut].values()) # type: ignore
@@ -384,7 +406,8 @@ def compute_activation_tensors_utilization(self,
384406
include_reused=False) \
385407
-> Tuple[float, Dict[NodeName, Utilization]]:
386408
"""
387-
Compute resource utilization for graph's activations tensors.
409+
Compute resource utilization for graph's activations tensors. If activation quantization configs are provided, then for
410+
quantization preserving nodes, get the previous quantized activation node bit-width.
388411
389412
Args:
390413
target_criterion: criterion to include weights for computation.
@@ -405,7 +428,7 @@ def compute_activation_tensors_utilization(self,
405428

406429
util_per_node: Dict[NodeName, Utilization] = {}
407430
for n in self._topo_sort(nodes):
408-
qc = act_qcs.get(n.name) if act_qcs else None
431+
qc = self._extract_qc(n, act_qcs)
409432
util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
410433
util_per_node[n.name] = util
411434

@@ -440,7 +463,7 @@ def compute_node_activation_tensor_utilization(self,
440463
return Utilization(0, 0)
441464

442465
size = self._act_tensors_size[n.name]
443-
nbits = self._get_activation_nbits(self.graph.retrieve_preserved_quantization_node(n), bitwidth_mode, qc)
466+
nbits = self._get_activation_nbits(n, bitwidth_mode, qc)
444467
bytes_ = size * nbits / 8
445468
return Utilization(size, bytes_)
446469

0 commit comments

Comments
 (0)