Skip to content

Commit a777853

Browse files
irenab
authored and committed
add comments and validation check
1 parent 9da7773 commit a777853

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,13 @@ def __init__(self,
165165
self.original_weights_node = weights_node
166166

167167
v_candidates = []
168-
kernel_attr = fw_info.get_kernel_op_attributes(weights_node.type)[0]
168+
kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
169+
assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, 'Expected exactly one kernel attr.'
170+
kernel_attr = kernel_attrs[0]
171+
conf_attrs = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
172+
if len(conf_attrs) > 1 or len(conf_attrs) == 1 and conf_attrs[0] != kernel_attr: # pragma: no cover
173+
raise NotImplementedError('Only kernel attr can be configurable.')
174+
169175
weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
170176
for c_a in act_node.candidates_quantization_cfg:
171177
for c_w in weights_candidates_quantization_cfg:
@@ -182,7 +188,6 @@ def __init__(self,
182188
v_candidates.append(composed_candidate)
183189

184190
# sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
185-
kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
186191
v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
187192
c.activation_quantization_cfg.activation_n_bits), reverse=True)
188193

model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,8 @@ def compute_node_bops(self,
544544
# we don't need the original node (and cannot use it for custom configuration anyway)
545545
a_node = n
546546
else:
547+
# if we are running on the original (non-virtual) graph, we only compute bops if it would be computed in an
548+
# equivalent virtual graph for consistency.
547549
a_node = get_input_activation_if_composable(self.graph, n, warn=False)
548550
if a_node is None:
549551
return 0

0 commit comments

Comments
 (0)