Skip to content

Commit 31bae84

Browse files
irenabirenab
authored andcommitted
fix virtual node with duplicate weights attributes in activation and weight nodes
1 parent 20c321c commit 31bae84

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import uuid
1516

1617
from typing import Dict, Any, Tuple
1718

@@ -146,16 +147,26 @@ def __init__(self,
146147
raise NotImplementedError('Only kernel weight can be configurable.') # pragma: no cover
147148

148149
weights = weights_node.weights
150+
act_node_w_rename = {}
149151
if act_node.weights:
150152
assert fw_info.get_kernel_op_attributes(act_node)[0] is None, \
151153
f'Node {act_node} with kernel cannot be used as activation for VirtualActivationWeightsNode.'
152-
if set(weights_node.weights.keys()).intersection(set(act_node.weights.keys())):
153-
raise ValueError('Activation and weight nodes are not expected to have the same weight attribute') # pragma: no cover
154154
if act_node.has_any_configurable_weight():
155155
raise NotImplementedError('Node with a configurable weight cannot be used as activation for '
156156
'VirtualActivationWeightsNode.') # pragma: no cover
157157
# combine weights from activation and weights
158-
weights.update(act_node.weights)
158+
for w_id, w in act_node.weights.items():
159+
if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):
160+
weights[w_id] = w
161+
continue
162+
# if same identifier is used as in weight nodes (or contains the kernel substring), generate a new
163+
# unique id. If positional, generate a new (and clearly made up) index.
164+
# This only serves for resource utilization computation so in theory this shouldn't matter, as long as
165+
# quantization config dict keys are updated accordingly.
166+
uniq_id = uuid.uuid4().hex[:8] if isinstance(w_id, str) else (100 + w_id)
167+
assert uniq_id not in weights
168+
act_node_w_rename[w_id] = uniq_id
169+
weights[uniq_id] = w
159170

160171
name = f"{VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX}_{act_node.name}_{weights_node.name}"
161172
super().__init__(name,
@@ -181,10 +192,12 @@ def __init__(self,
181192
if act_node.weights:
182193
# add non-kernel weights cfg from activation node to the composed node's weights cfg
183194
composed_candidate.weights_quantization_cfg.attributes_config_mapping.update(
184-
c_a.weights_quantization_cfg.attributes_config_mapping
195+
{act_node_w_rename.get(k, k): v
196+
for k, v in c_a.weights_quantization_cfg.attributes_config_mapping.items()}
185197
)
186198
composed_candidate.weights_quantization_cfg.pos_attributes_config_mapping.update(
187-
c_a.weights_quantization_cfg.pos_attributes_config_mapping
199+
{act_node_w_rename.get(k, k): v
200+
for k, v in c_a.weights_quantization_cfg.pos_attributes_config_mapping.items()}
188201
)
189202
v_candidates.append(composed_candidate)
190203

0 commit comments

Comments
 (0)