diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index a88dcbc2..23616c20 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -169,7 +169,7 @@ def from_pretrained_model( cls, model: Module, sparsity_config: Union[SparsityCompressionConfig, str, None] = None, - quantization_format: Optional[Union[str, List[str]]] = None, + quantization_format: Optional[Union[str, CompressionFormat, List[str], List[CompressionFormat]]] = None, ) -> Optional["ModelCompressor"]: """ Given a pytorch model and optional sparsity and/or quantization configs, @@ -203,7 +203,7 @@ def from_pretrained_model( quantization_config=quantization_config, transform_config=transform_config, compression_formats=[quantization_format] - if isinstance(quantization_format, str) + if not isinstance(quantization_format, list) else quantization_format, ) @@ -315,10 +315,11 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: + name = format.value if isinstance(format, CompressionFormat) else format self.quantization_compressor[ format ] = BaseCompressor.load_from_registry( - format, config=quantization_config + name, config=quantization_config ) # ----- used by hf quantizer ----- #