@@ -22,6 +22,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from huggingface_hub.utils import EntryNotFoundError
 from neural_compressor.utils.pytorch import load
 from transformers import (
     AutoConfig,
@@ -40,14 +41,15 @@
 )
 from transformers.modeling_utils import no_init_weights
 from transformers.models.auto.auto_factory import _get_model_class
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from transformers.utils.generic import ContextManagers
 
 from optimum.intel.generation import BaseModelForCausalLM
 
 from ...modeling_base import OptimizedModel
 from ..utils.import_utils import _torch_version, is_itrex_available, is_torch_version
 from .configuration import INCConfig
-from .utils import WEIGHTS_NAME
+from .utils import QUANTIZATION_CONFIG_NAME
 
 
 logger = logging.getLogger(__name__)
@@ -119,33 +121,70 @@ def _from_pretrained(
                 raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
             token = use_auth_token
 
-        model_name_or_path = kwargs.pop("model_name_or_path", None)
-        if model_name_or_path is not None:
-            logger.warning("`model_name_or_path` is deprecated please use `model_id`")
-            model_id = model_id or model_name_or_path
-
         model_path = Path(model_id)
-
-        if model_path.is_dir():
-            model_cache_path = model_path / file_name
+        is_local = model_path.is_dir()
+        model_cache_path = None
+        inc_config = None
+        msg = None
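+        # Resolve the checkpoint path: prefer a safetensors file when present, otherwise fall back to `file_name`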
+        if is_local:
+            if (model_path / subfolder / SAFE_WEIGHTS_NAME).is_file():
+                file_name = SAFE_WEIGHTS_NAME
+            elif not (model_path / subfolder / file_name).is_file():
+                raise EnvironmentError(
+                    f"Error no file named {SAFE_WEIGHTS_NAME} or {file_name} found in directory {model_path / subfolder}"
+                )
+            model_cache_path = model_path / subfolder / file_name
         else:
-            model_cache_path = hf_hub_download(
-                repo_id=model_id,
-                filename=file_name,
-                subfolder=subfolder,
-                token=token,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                local_files_only=local_files_only,
-            )
+            # Try to download the safetensors weights first, in case they exist
+            try:
+                model_cache_path = hf_hub_download(
+                    repo_id=model_id,
+                    filename=SAFE_WEIGHTS_NAME,
+                    subfolder=subfolder,
+                    token=token,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                )
+            except EntryNotFoundError:
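+                # No safetensors checkpoint on the Hub, fall back to the requested weights file below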
+                pass
+
+            if model_cache_path is None:
+                model_cache_path = hf_hub_download(
+                    repo_id=model_id,
+                    filename=file_name,
+                    subfolder=subfolder,
+                    token=token,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                )
 
         model_save_dir = Path(model_cache_path).parent
-        inc_config = None
-        msg = None
+
         if is_itrex_available():
-            try:
-                quantization_config = PretrainedConfig.from_pretrained(model_save_dir / "quantize_config.json")
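+            # Locate the quantization config, either locally or on the Hub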
+            quantization_config_path = None
+            if is_local:
+                quantization_config_path = model_path / subfolder / QUANTIZATION_CONFIG_NAME
+            else:
+                try:
+                    quantization_config_path = hf_hub_download(
+                        repo_id=model_id,
+                        filename=QUANTIZATION_CONFIG_NAME,
+                        subfolder=subfolder,
+                        token=token,
+                        revision=revision,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        local_files_only=local_files_only,
+                    )
+                except EntryNotFoundError:
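+                    # The repository has no quantization config, so the model is not weight-only quantized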
+                    pass
+
+            if quantization_config_path and Path(quantization_config_path).is_file():
+                quantization_config = PretrainedConfig.from_pretrained(quantization_config_path)
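+                # Weight-only quantized models (RTN, GPTQ, AWQ, AutoRound) are loaded through ITREX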
                 algorithm = getattr(quantization_config, "quant_method", None)
                 if algorithm in {"rtn", "gptq", "awq", "autoround"}:
                     from intel_extension_for_transformers.transformers.modeling.modeling_auto import (
@@ -154,7 +193,7 @@ def _from_pretrained(
 
                     _BaseQBitsAutoModelClass.ORIG_MODEL = cls.auto_model_class
 
-                    return _BaseQBitsAutoModelClass.from_pretrained(
+                    model = _BaseQBitsAutoModelClass.from_pretrained(
                         pretrained_model_name_or_path=model_id,
                         token=token,
                         revision=revision,
@@ -163,12 +202,16 @@ def _from_pretrained(
                         local_files_only=local_files_only,
                         subfolder=subfolder,
                         trust_remote_code=trust_remote_code,
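+                        # Load as a plain PyTorch module rather than through the Neural Speed runtime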
+                        use_neural_speed=False,
                         **kwargs,
                     )
-            except EnvironmentError:
-                msg = "The model is not quantized with weight-only quantization."
+
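+                    # Wrap the loaded model so the returned object exposes the optimum-intel interface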
+                    return cls(
+                        model, config=config, model_save_dir=model_save_dir, q_config=quantization_config, **kwargs
+                    )
+
         try:
-            inc_config = INCConfig.from_pretrained(model_id)
+            inc_config = INCConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision)
             if not is_torch_version("==", inc_config.torch_version):
                 msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
                 logger.warning(f"{msg}")
@@ -209,15 +252,19 @@ def _from_pretrained(
         )
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
-        output_path = os.path.join(save_directory, WEIGHTS_NAME)
-
         if isinstance(self.model, torch.nn.Module):
-            state_dict = self.model.state_dict()
-            if self._q_config:
-                state_dict["best_configure"] = self._q_config
-            torch.save(state_dict, output_path)
+            # For an ITREX model, save the quantization config alongside the weights
+            if isinstance(self._q_config, PretrainedConfig):
+                self._q_config.to_json_file(os.path.join(save_directory, QUANTIZATION_CONFIG_NAME))
+                self.model.save_pretrained(save_directory)
+            # For an INC model, the state dictionary needs to be modified to include the quantization parameters
+            else:
+                state_dict = self.model.state_dict()
+                if isinstance(self._q_config, dict):
+                    state_dict["best_configure"] = self._q_config
+                torch.save(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
         else:
-            torch.jit.save(self.model, output_path)
+            torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME))
 
         if self.inc_config:
             self.inc_config.save_pretrained(save_directory)