Skip to content

Commit df315c9

Browse files
committed
Fix get_attr_config
1 parent 766d41f commit df315c9

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -460,19 +460,18 @@ def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantization
460460
if attr_name is None: # pragma: no cover
461461
Logger.critical("Got 'None' attribute name for retrieving weights attribute quantization configuration.")
462462

463-
if isinstance(attr_name, int):
463+
attrs_with_name = self._extract_config_for_attributes_with_name(attr_name)
464+
attr_cfg = None
465+
if len(attrs_with_name) == 0 and isinstance(attr_name, int):
464466
# this is a positional attribute
465467
attr_cfg = self.pos_attributes_config_mapping.get(attr_name)
466-
else:
467-
attrs_with_name = self._extract_config_for_attributes_with_name(attr_name)
468-
attr_cfg = None
469-
if len(attrs_with_name) == 1:
470-
attr_cfg = [v for v in attrs_with_name.values()][0]
471-
elif len(attrs_with_name) > 1:
472-
Logger.warning(f"Found multiple weight attributes containing the name {attr_name}: "
473-
f"{list(attrs_with_name.keys())}. Looking for an attributes with the exact name.")
474-
# If no attribute with the exact name then an error would be thrown
475-
attr_cfg = self.attributes_config_mapping.get(attr_name)
468+
if len(attrs_with_name) == 1:
469+
attr_cfg = [v for v in attrs_with_name.values()][0]
470+
elif len(attrs_with_name) > 1:
471+
Logger.warning(f"Found multiple weight attributes containing the name {attr_name}: "
472+
f"{list(attrs_with_name.keys())}. Looking for an attributes with the exact name.")
473+
# If no attribute with the exact name then an error would be thrown
474+
attr_cfg = self.attributes_config_mapping.get(attr_name)
476475

477476
if attr_cfg is None: # pragma: no cover
478477
Logger.critical(f"Weight attribute '{attr_name}' config could not be found.")
@@ -533,7 +532,7 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh
533532
Returns: A mapping between attributes that contain the given name to their configuration.
534533
535534
"""
536-
attrs_with_name = {k: v for k, v in self.attributes_config_mapping.items() if attr_name in k}
535+
attrs_with_name = {k: v for k, v in self.attributes_config_mapping.items() if str(attr_name) in str(k)}
537536
if len(attrs_with_name) > 1:
538537
Logger.warning(f"Found multiple weight attributes containing the name {attr_name}: "
539538
f"{list(attrs_with_name.keys())}.")

0 commit comments

Comments
 (0)