@@ -518,32 +518,20 @@ def post_init(self):
518518 TORCHAO_QUANT_TYPE_METHODS = self ._get_torchao_quant_type_to_method ()
519519 AO_VERSION = self ._get_ao_version ()
520520
521- if isinstance (self .quant_type , str ) and self .quant_type not in TORCHAO_QUANT_TYPE_METHODS .keys ():
522- is_floating_quant_type = self .quant_type .startswith ("float" ) or self .quant_type .startswith ("fp" )
523- if is_floating_quant_type and not self ._is_xpu_or_cuda_capability_atleast_8_9 ():
521+ if isinstance (self .quant_type , str ):
522+ if self .quant_type not in TORCHAO_QUANT_TYPE_METHODS .keys ():
523+ is_floating_quant_type = self .quant_type .startswith ("float" ) or self .quant_type .startswith ("fp" )
524+ if is_floating_quant_type and not self ._is_xpu_or_cuda_capability_atleast_8_9 ():
525+ raise ValueError (
526+ f"Requested quantization type: { self .quant_type } is not supported on GPUs with CUDA capability <= 8.9. You "
527+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
528+ )
529+
524530 raise ValueError (
525- f"Requested quantization type: { self .quant_type } is not supported on GPUs with CUDA capability <= 8.9. You "
526- f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()` ."
531+ f"Requested quantization type: { self .quant_type } is not supported or is an incorrect `quant_type` name. If you think the "
532+ f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues ."
527533 )
528534
529- raise ValueError (
530- f"Requested quantization type: { self .quant_type } is not supported or is an incorrect `quant_type` name. If you think the "
531- f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
532- )
533- elif AO_VERSION > version .parse ("0.9.0" ):
534- from torchao .quantization .quant_api import AOBaseConfig
535-
536- if not isinstance (self .quant_type , AOBaseConfig ):
537- raise TypeError (
538- f"`quant_type` must be either a string or an `AOBaseConfig` instance, got { type (self .quant_type )} ."
539- )
540- else :
541- raise ValueError (
542- f"In torchao <= 0.9.0, quant_type must be a string. Got { type (self .quant_type )} . "
543- f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
544- )
545-
546- if isinstance (self .quant_type , str ):
547535 method = TORCHAO_QUANT_TYPE_METHODS [self .quant_type ]
548536 signature = inspect .signature (method )
549537 all_kwargs = {
@@ -558,6 +546,18 @@ def post_init(self):
558546 f'The quantization method "{ self .quant_type } " does not support the following keyword arguments: '
559547 f"{ unsupported_kwargs } . The following keywords arguments are supported: { all_kwargs } ."
560548 )
549+ elif AO_VERSION > version .parse ("0.9.0" ):
550+ from torchao .quantization .quant_api import AOBaseConfig
551+
552+ if not isinstance (self .quant_type , AOBaseConfig ):
553+ raise TypeError (
554+ f"`quant_type` must be either a string or an `AOBaseConfig` instance, got { type (self .quant_type )} ."
555+ )
556+ else :
557+ raise ValueError (
558+ f"In torchao <= 0.9.0, quant_type must be a string. Got { type (self .quant_type )} . "
559+ f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
560+ )
561561
562562 def to_dict (self ):
563563 """Convert configuration to a dictionary."""
0 commit comments