@@ -460,19 +460,18 @@ def get_attr_config(self, attr_name: Union[str, int]) -> WeightsAttrQuantization
460460 if attr_name is None : # pragma: no cover
461461 Logger .critical ("Got 'None' attribute name for retrieving weights attribute quantization configuration." )
462462
463- if isinstance (attr_name , int ):
463+ attrs_with_name = self ._extract_config_for_attributes_with_name (attr_name )
464+ attr_cfg = None
465+ if len (attrs_with_name ) == 0 and isinstance (attr_name , int ):
464466 # this is a positional attribute
465467 attr_cfg = self .pos_attributes_config_mapping .get (attr_name )
466- else :
467- attrs_with_name = self ._extract_config_for_attributes_with_name (attr_name )
468- attr_cfg = None
469- if len (attrs_with_name ) == 1 :
470- attr_cfg = [v for v in attrs_with_name .values ()][0 ]
471- elif len (attrs_with_name ) > 1 :
472- Logger .warning (f"Found multiple weight attributes containing the name { attr_name } : "
473- f"{ list (attrs_with_name .keys ())} . Looking for an attributes with the exact name." )
474- # If no attribute with the exact name then an error would be thrown
475- attr_cfg = self .attributes_config_mapping .get (attr_name )
468+ if len (attrs_with_name ) == 1 :
469+ attr_cfg = [v for v in attrs_with_name .values ()][0 ]
470+ elif len (attrs_with_name ) > 1 :
471+ Logger .warning (f"Found multiple weight attributes containing the name { attr_name } : "
472+ f"{ list (attrs_with_name .keys ())} . Looking for an attributes with the exact name." )
473+ # If no attribute with the exact name then an error would be thrown
474+ attr_cfg = self .attributes_config_mapping .get (attr_name )
476475
477476 if attr_cfg is None : # pragma: no cover
478477 Logger .critical (f"Weight attribute '{ attr_name } ' config could not be found." )
@@ -533,7 +532,7 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh
533532 Returns: A mapping between attributes that contain the given name to their configuration.
534533
535534 """
536- attrs_with_name = {k : v for k , v in self .attributes_config_mapping .items () if attr_name in k }
535+ attrs_with_name = {k : v for k , v in self .attributes_config_mapping .items () if str ( attr_name ) in str ( k ) }
537536 if len (attrs_with_name ) > 1 :
538537 Logger .warning (f"Found multiple weight attributes containing the name { attr_name } : "
539538 f"{ list (attrs_with_name .keys ())} ." )
0 commit comments