3434
3535from packaging import version
3636
37- from ..utils import is_torch_available , is_torchao_available , logging
37+ from ..utils import is_torch_available , is_torchao_available , is_torchao_version , logging
3838
3939
4040if is_torch_available ():
@@ -516,7 +516,6 @@ def __init__(
516516
517517 def post_init (self ):
518518 TORCHAO_QUANT_TYPE_METHODS = self ._get_torchao_quant_type_to_method ()
519- AO_VERSION = self ._get_ao_version ()
520519
521520 if isinstance (self .quant_type , str ):
522521 if self .quant_type not in TORCHAO_QUANT_TYPE_METHODS .keys ():
@@ -546,7 +545,7 @@ def post_init(self):
546545 f'The quantization method "{ self .quant_type } " does not support the following keyword arguments: '
547546 f"{ unsupported_kwargs } . The following keywords arguments are supported: { all_kwargs } ."
548547 )
549- elif AO_VERSION > version . parse ( "0.9.0" ):
548+ elif is_torchao_version ( ">" , "0.9.0" ):
550549 from torchao .quantization .quant_api import AOBaseConfig
551550
552551 if not isinstance (self .quant_type , AOBaseConfig ):
@@ -590,8 +589,8 @@ def to_dict(self):
590589 @classmethod
591590 def from_dict (cls , config_dict , return_unused_kwargs = False , ** kwargs ):
592591 """Create configuration from a dictionary."""
593- ao_version = cls . _get_ao_version ()
594- assert ao_version > version . parse ( "0.9.0" ), " TorchAoConfig requires torchao > 0.9.0 for construction from dict"
592+ if not is_torchao_version ( ">" , "0.9.0" ):
593+ raise NotImplementedError ( " TorchAoConfig requires torchao > 0.9.0 for construction from dict")
595594 config_dict = config_dict .copy ()
596595 quant_type = config_dict .pop ("quant_type" )
597596
@@ -611,14 +610,6 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
611610
612611 return cls (quant_type = quant_type , ** config_dict )
613612
614- @staticmethod
615- def _get_ao_version () -> version .Version :
616- """Centralized check for TorchAO availability and version requirements."""
617- if not is_torchao_available ():
618- raise ValueError ("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`" )
619-
620- return version .parse (importlib .metadata .version ("torchao" ))
621-
622613 @classmethod
623614 def _get_torchao_quant_type_to_method (cls ):
624615 r"""
0 commit comments