@@ -18,14 +18,12 @@
 import types
 import warnings
 from enum import Enum
-from itertools import chain
 from pathlib import Path
 from typing import Callable, Optional, Union
 
 import torch
 from datasets import Dataset, load_dataset
 from neural_compressor.config import PostTrainingQuantConfig
-from neural_compressor.model.onnx_model import ONNXModel
 from neural_compressor.model.torch_model import IPEXModel, PyTorchModel
 from neural_compressor.quantization import fit
 from packaging.version import parse
@@ -38,14 +36,9 @@
 )
 
 from optimum.exporters import TasksManager
-from optimum.exporters.onnx import OnnxConfig
-from optimum.onnxruntime import ORTModel
-from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM
-from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration
-from optimum.onnxruntime.utils import ONNX_DECODER_NAME
 from optimum.quantization_base import OptimumQuantizer
 
-from ..utils.constant import _TASK_ALIASES, MIN_QDQ_ONNX_OPSET, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
+from ..utils.constant import _TASK_ALIASES, WEIGHTS_NAME
 from ..utils.import_utils import (
     ITREX_IMPORT_ERROR,
     _ipex_version,
@@ -79,12 +72,6 @@
 )
 
 
-if is_neural_compressor_version("<", "2.6"):
-    from neural_compressor.experimental.export import torch_to_int8_onnx
-else:
-    from neural_compressor.utils.export import torch_to_int8_onnx
-
-
 if is_itrex_available():
     if is_itrex_version("<", ITREX_MINIMUM_VERSION):
         raise ImportError(
@@ -129,7 +116,7 @@ class INCQuantizer(OptimumQuantizer):
 
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: Union[PreTrainedModel, torch.nn.Module],
         eval_fn: Optional[Callable[[PreTrainedModel], int]] = None,
         calibration_fn: Optional[Callable[[PreTrainedModel], int]] = None,
         task: Optional[str] = None,
@@ -143,7 +130,7 @@ def __init__(
                 The evaluation function to use for the accuracy driven strategy of the quantization process.
                 The accuracy driven strategy will be enabled only if `eval_fn` is provided.
             task (`str`, defaults to None):
-                The task defining the model topology used for the ONNX export.
+                The task defining the model topology. Will try to infer it from the model if not provided.
             seed (`int`, defaults to 42):
                 The random seed to use when shuffling the calibration dataset.
         """
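
As a hedged usage sketch of the widened signature (the model class and checkpoint id are placeholders, not taken from this patch):

    from transformers import AutoModelForSequenceClassification
    from optimum.intel import INCQuantizer

    # placeholder checkpoint; any PreTrainedModel (or plain torch.nn.Module) is accepted
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english"
    )
    # `task` can be omitted: _set_task() will try to infer it from the model object
    quantizer = INCQuantizer.from_pretrained(model)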
@@ -194,23 +181,11 @@ def quantize(
         """
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
-        save_onnx_model = kwargs.pop("save_onnx_model", False)
         device = kwargs.pop("device", "cpu")
         use_cpu = device == torch.device("cpu") or device == "cpu"
         use_xpu = device == torch.device("xpu") or device == "xpu"
         calibration_dataloader = None
-
-        if save_onnx_model:
-            logger.warning("ONNX model export is deprecated and will be removed soon.")
-
-        if isinstance(self._original_model, ORTModel):
-            logger.warning("ONNX model quantization is deprecated and will be removed soon.")
-
-        if save_onnx_model and isinstance(self._original_model, ORTModel):
-            logger.warning("The model provided is already an ONNX model. Setting `save_onnx_model` to False.")
-            save_onnx_model = False
-
-        default_name = WEIGHTS_NAME if not isinstance(self._original_model, ORTModel) else ONNX_WEIGHTS_NAME
+        default_name = WEIGHTS_NAME
         self._set_task()
 
         if kwargs.pop("weight_only", None) is not None:
@@ -229,17 +204,6 @@ def quantize(
229204 f"but only version { IPEX_MINIMUM_VERSION } or higher is supported."
230205 )
231206
232- if save_onnx_model and (
233- not isinstance (quantization_config , PostTrainingQuantConfig )
234- or INCQuantizationMode (quantization_config .approach ) == INCQuantizationMode .DYNAMIC
235- ):
236- logger .warning (
237- "ONNX export for dynamic and weight only quantized model is not supported. "
238- "Only static quantization model can be exported to ONNX format. "
239- "Setting `save_onnx_model` to False."
240- )
241- save_onnx_model = False
242-
243207 # ITREX Weight Only Quantization
244208 if not isinstance (quantization_config , PostTrainingQuantConfig ):
245209 if is_itrex_version ("==" , "1.4.2" ) and (
@@ -306,14 +270,6 @@ def quantize(
                 data_collator=data_collator,
             )
 
-        op_type_dict = getattr(quantization_config, "op_type_dict", None)
-        if save_onnx_model and (op_type_dict is None or "Embedding" not in op_type_dict):
-            logger.warning(
-                "ONNX export is not supported for models with quantized embeddings. "
-                "Setting `save_onnx_model` to False."
-            )
-            save_onnx_model = False
-
         if not isinstance(quantization_config, PostTrainingQuantConfig):
             if use_cpu:
                 # will remove after intel-extension-for-transformers 1.3.3 release.
@@ -336,26 +292,8 @@ def quantize(
         if isinstance(self._original_model.config, PretrainedConfig):
             self._original_model.config.backend = quantization_config.backend
 
-        if isinstance(self._original_model, ORTModel):
-            # TODO: enable seq2seq models
-            if isinstance(self._original_model, ORTModelForConditionalGeneration):
-                raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")
-
-            if isinstance(self._original_model, ORTModelForCausalLM):
-                model_or_path = self._original_model.onnx_paths
-                if len(model_or_path) > 1:
-                    raise RuntimeError(
-                        f"Too many ONNX model files were found in {self._original_model.onnx_paths}, only `use_cache=False` is supported"
-                    )
-                model_or_path = str(model_or_path[0])
-                default_name = ONNX_DECODER_NAME
-            else:
-                model_or_path = str(self._original_model.model_path)
-        else:
-            model_or_path = self._original_model
-
         compressed_model = fit(
-            model_or_path,
+            self._original_model,
             conf=quantization_config,
             calib_dataloader=calibration_dataloader,
             eval_func=self.eval_fn,
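
For reference, `fit` above is neural_compressor's standard quantization entry point, now always called on the PyTorch module itself. A minimal standalone sketch of the equivalent direct call, assuming a `model` module and a `calib_dataloader` that are not defined in this patch:

    from neural_compressor.config import PostTrainingQuantConfig
    from neural_compressor.quantization import fit

    # `model` (torch.nn.Module) and `calib_dataloader` are placeholders
    conf = PostTrainingQuantConfig(approach="static")
    compressed_model = fit(model, conf=conf, calib_dataloader=calib_dataloader)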
@@ -373,40 +311,20 @@ def quantize(
         if isinstance(compressed_model, IPEXModel):
             model_config.torchscript = True
             model_config.backend = "ipex"
-        elif not isinstance(compressed_model, ONNXModel):
-            compressed_model._model.config = model_config
         model_config.save_pretrained(save_directory)
 
         self._quantized_model = compressed_model._model
 
-        if save_onnx_model:
-            model_type = self._original_model.config.model_type.replace("_", "-")
-            model_name = getattr(self._original_model, "name", None)
-            onnx_config_class = TasksManager.get_exporter_config_constructor(
-                exporter="onnx",
-                model=self._original_model,
-                task=self.task,
-                model_type=model_type,
-                model_name=model_name,
-            )
-            onnx_config = onnx_config_class(self._original_model.config)
-            compressed_model.eval()
-            output_onnx_path = save_directory.joinpath(ONNX_WEIGHTS_NAME)
-            # Export the compressed model to the ONNX format
-            self._onnx_export(compressed_model, onnx_config, output_onnx_path)
-
         output_path = save_directory.joinpath(file_name or default_name)
         # Save the quantized model
         self._save_pretrained(compressed_model, output_path)
-        quantization_config = INCConfig(quantization=quantization_config, save_onnx_model=save_onnx_model)
+        quantization_config = INCConfig(quantization=quantization_config)
         quantization_config.save_pretrained(save_directory)
 
     @staticmethod
     def _save_pretrained(model: Union[PyTorchModel, IPEXModel], output_path: str):
         if isinstance(model, IPEXModel):
             model._model.save(output_path)
-        elif isinstance(model, ONNXModel):
-            model.save(output_path)
         else:
             state_dict = model._model.state_dict()
             if hasattr(model, "q_config"):
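
Because the q_config is persisted next to the weights, the resulting checkpoint is not a plain `load_state_dict` target. A hedged sketch of the usual reload path (`fp32_model` and the directory name are placeholders):

    # neural_compressor's loader re-applies the stored q_config to a fresh fp32 model
    from neural_compressor.utils.pytorch import load

    int8_model = load("quantized_model_dir", fp32_model)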
@@ -415,46 +333,12 @@ def _save_pretrained(model: Union[PyTorchModel, IPEXModel], output_path: str):
 
         logger.info(f"Model weights saved to {output_path}")
 
-    def _onnx_export(
-        self,
-        model: PyTorchModel,
-        config: OnnxConfig,
-        output_path: Union[str, Path],
-    ):
-        opset = max(config.DEFAULT_ONNX_OPSET, MIN_QDQ_ONNX_OPSET)
-        dynamic_axes = dict(chain(config.inputs.items(), config.outputs.items()))
-        inputs = config.generate_dummy_inputs(framework="pt")
-        device = model.model.device
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-
-        if is_neural_compressor_version(">", "2.2.1"):
-            torch_to_int8_onnx(
-                self._original_model,
-                model.model,
-                q_config=model.q_config,
-                save_path=str(output_path),
-                example_inputs=inputs,
-                opset_version=opset,
-                dynamic_axes=dynamic_axes,
-                input_names=list(config.inputs.keys()),
-                output_names=list(config.outputs.keys()),
-            )
-        else:
-            torch_to_int8_onnx(
-                model.model,
-                q_config=model.q_config,
-                save_path=str(output_path),
-                example_inputs=inputs,
-                opset_version=opset,
-                dynamic_axes=dynamic_axes,
-                input_names=list(config.inputs.keys()),
-                output_names=list(config.outputs.keys()),
-            )
-
     def _set_task(self):
         if self.task is None:
             try:
-                self.task = TasksManager.infer_task_from_model(self._original_model.config._name_or_path)
+                # using the actual model has a better chance of success,
+                # since inferring from the model path does not work with local models
+                self.task = TasksManager.infer_task_from_model(self._original_model)
             except Exception as e:
                 self.task = "default"
                 logger.warning(
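
Taken together, a hedged end-to-end sketch of the quantization flow after this change; the `save_onnx_model` argument is gone, and dynamic quantization needs no calibration set (`model` and the save path are placeholders, as in the earlier sketches):

    from neural_compressor.config import PostTrainingQuantConfig
    from optimum.intel import INCQuantizer

    quantizer = INCQuantizer.from_pretrained(model)
    quantizer.quantize(
        quantization_config=PostTrainingQuantConfig(approach="dynamic"),
        save_directory="quantized_model_dir",  # placeholder path
    )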