Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class BitwidthMode(Enum):
single-precision nodes. To compute custom single precision configuration, use QCustom.
"""
Float = auto()
Q8Bit = auto()
QMaxBit = auto()
QMinBit = auto()
QCustom = auto()
Expand Down Expand Up @@ -573,7 +572,7 @@ def compute_node_bops(self,
not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
return 0

act_qc = act_qcs.get(a_node.name) if act_qcs else None
act_qc = self._extract_qc(a_node, act_qcs)
a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
node_bops = a_nbits * w_nbits * node_mac
Expand Down Expand Up @@ -708,23 +707,20 @@ def _get_activation_nbits(self,
Returns:
Activation bit-width.
"""
n = self.graph.retrieve_preserved_quantization_node(n)
if act_qc:
assert bitwidth_mode == BitwidthMode.QCustom
return act_qc.activation_n_bits if act_qc.quant_mode == ActivationQuantizationMode.QUANT else FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
n.is_quantization_preserving()):
if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
return FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.Q8Bit:
return 8

if bitwidth_mode in self._bitwidth_mode_fn:
candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)

if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
qcs = self.graph.retrieve_preserved_quantization_node(n).get_unique_activation_candidates()
qcs = n.get_unique_activation_candidates()
if len(qcs) != 1:
raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
f'as it has {len(qcs)}!=1 unique candidates.')
Expand Down Expand Up @@ -760,9 +756,6 @@ def _get_weight_nbits(cls,
if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr):
return FLOAT_BITWIDTH

if bitwidth_mode == BitwidthMode.Q8Bit:
return 8

node_qcs = n.get_unique_weights_candidates(w_attr)
w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs]
if bitwidth_mode in cls._bitwidth_mode_fn:
Expand Down
1 change: 1 addition & 0 deletions tests_pytest/_test_util/graph_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, (
Final name can be passed along with convert_canonical_attr=False.
pos_attr: quantization configuration for positional weights in format (nbits, q enabled, indices).
convert_canonical_attr: whether to convert w_attr keys to full names.
q_preserving: Whether node is quantization preserving.

Returns:

Expand Down
Loading