5757)
5858
5959from ...exporters .openvino import main_export
60- from .configuration import OVConfig , OVWeightQuantizationConfig
60+ from .configuration import OVConfig , OVQuantizationMethod , OVWeightQuantizationConfig
6161from .loaders import OVTextualInversionLoaderMixin
6262from .modeling_base import OVBaseModel
6363from .utils import (
6464 ONNX_WEIGHTS_NAME ,
6565 OV_TO_NP_TYPE ,
6666 OV_XML_FILE_NAME ,
67- PREDEFINED_SD_DATASETS ,
6867 _print_compiled_model_properties ,
6968)
7069
@@ -293,35 +292,27 @@ def _from_pretrained(
293292 else :
294293 kwargs [name ] = load_method (new_model_save_dir )
295294
296- quantization_config = cls ._prepare_weight_quantization_config (quantization_config , load_in_8bit )
297-
298295 unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
299- if quantization_config is not None and quantization_config .dataset is not None :
300- # load the UNet model uncompressed to apply hybrid quantization further
301- unet = cls .load_model (unet_path )
302- # Apply weights compression to other `components` without dataset
303- weight_quantization_params = {
304- param : value for param , value in quantization_config .__dict__ .items () if param != "dataset"
305- }
306- weight_quantization_config = OVWeightQuantizationConfig .from_dict (weight_quantization_params )
307- else :
308- weight_quantization_config = quantization_config
309- unet = cls .load_model (unet_path , weight_quantization_config )
310-
311296 components = {
312297 "vae_encoder" : new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name ,
313298 "vae_decoder" : new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name ,
314299 "text_encoder" : new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name ,
315300 "text_encoder_2" : new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name ,
316301 }
317302
318- for key , value in components .items ():
319- components [key ] = cls .load_model (value , weight_quantization_config ) if value .is_file () else None
320-
321303 if model_save_dir is None :
322304 model_save_dir = new_model_save_dir
323305
324- if quantization_config is not None and quantization_config .dataset is not None :
306+ quantization_config = cls ._prepare_weight_quantization_config (quantization_config , load_in_8bit )
307+ if quantization_config is None or quantization_config .dataset is None :
308+ unet = cls .load_model (unet_path , quantization_config )
309+ for key , value in components .items ():
310+ components [key ] = cls .load_model (value , quantization_config ) if value .is_file () else None
311+ else :
312+ # Load uncompressed models to apply hybrid quantization further
313+ unet = cls .load_model (unet_path )
314+ for key , value in components .items ():
315+ components [key ] = cls .load_model (value ) if value .is_file () else None
325316 sd_model = cls (unet = unet , config = config , model_save_dir = model_save_dir , ** components , ** kwargs )
326317
327318 supported_pipelines = (
@@ -332,12 +323,14 @@ def _from_pretrained(
332323 if not isinstance (sd_model , supported_pipelines ):
333324 raise NotImplementedError (f"Quantization in hybrid mode is not supported for { cls .__name__ } " )
334325
335- nsamples = quantization_config .num_samples if quantization_config .num_samples else 200
336- unet_inputs = sd_model ._prepare_unet_inputs (quantization_config .dataset , nsamples )
326+ from optimum .intel import OVQuantizer
337327
338- from .quantization import _hybrid_quantization
328+ hybrid_quantization_config = deepcopy (quantization_config )
329+ hybrid_quantization_config .quant_method = OVQuantizationMethod .HYBRID
330+ quantizer = OVQuantizer (sd_model )
331+ quantizer .quantize (ov_config = OVConfig (quantization_config = hybrid_quantization_config ))
339332
340- unet = _hybrid_quantization ( sd_model . unet . model , weight_quantization_config , dataset = unet_inputs )
333+ return sd_model
341334
342335 return cls (
343336 unet = unet ,
@@ -348,62 +341,6 @@ def _from_pretrained(
348341 ** kwargs ,
349342 )
350343
351- def _prepare_unet_inputs (
352- self ,
353- dataset : Union [str , List [Any ]],
354- num_samples : int ,
355- height : Optional [int ] = None ,
356- width : Optional [int ] = None ,
357- seed : Optional [int ] = 42 ,
358- ** kwargs ,
359- ) -> Dict [str , Any ]:
360- self .compile ()
361-
362- size = self .unet .config .get ("sample_size" , 64 ) * self .vae_scale_factor
363- height = height or min (size , 512 )
364- width = width or min (size , 512 )
365-
366- if isinstance (dataset , str ):
367- dataset = deepcopy (dataset )
368- available_datasets = PREDEFINED_SD_DATASETS .keys ()
369- if dataset not in available_datasets :
370- raise ValueError (
371- f"""You have entered a string value for dataset. You can only choose between
372- { list (available_datasets )} , but the { dataset } was found"""
373- )
374-
375- from datasets import load_dataset
376-
377- dataset_metadata = PREDEFINED_SD_DATASETS [dataset ]
378- dataset = load_dataset (dataset , split = dataset_metadata ["split" ], streaming = True ).shuffle (seed = seed )
379- input_names = dataset_metadata ["inputs" ]
380- dataset = dataset .select_columns (list (input_names .values ()))
381-
382- def transform_fn (data_item ):
383- return {inp_name : data_item [column ] for inp_name , column in input_names .items ()}
384-
385- else :
386-
387- def transform_fn (data_item ):
388- return data_item if isinstance (data_item , (list , dict )) else [data_item ]
389-
390- from .quantization import InferRequestWrapper
391-
392- calibration_data = []
393- self .unet .request = InferRequestWrapper (self .unet .request , calibration_data )
394-
395- for inputs in dataset :
396- inputs = transform_fn (inputs )
397- if isinstance (inputs , dict ):
398- self .__call__ (** inputs , height = height , width = width )
399- else :
400- self .__call__ (* inputs , height = height , width = width )
401- if len (calibration_data ) >= num_samples :
402- break
403-
404- self .unet .request = self .unet .request .request
405- return calibration_data [:num_samples ]
406-
407344 @classmethod
408345 def _from_transformers (
409346 cls ,
0 commit comments