@@ -143,6 +143,8 @@ def __init__(
         self._supports_sdpa = getattr(model, "_supports_sdpa", None)
         self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
         self._supports_static_cache = getattr(model, "_supports_static_cache", None)
+        self._can_compile_fullgraph = getattr(model, "_can_compile_fullgraph", False)
+        self._tp_size = getattr(model, "_tp_size", None)
         self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
@@ -182,7 +184,7 @@ def from_pretrained(
         model = cls.auto_model_class.from_pretrained(model_id, **kwargs)
         if getattr(model.config, "torchscript", False):
             raise ValueError("IPEXModel is no longer support torchscript models.")
-        return cls(model, config=kwargs.pop("config", model.config), **kwargs)
+        return cls(model, config=kwargs.pop("config", model.config), model_save_dir=model_id, **kwargs)

     def _save_pretrained(self, save_directory: Union[str, Path]):
         self.model.save_pretrained(save_directory, safe_serialization=False)
@@ -207,6 +209,13 @@ def device(self) -> torch.device:
     def dtype(self) -> torch.dtype:
         return self._dtype

+    @property
+    def tp_size(self):
+        """
+        Returns the model's tensor parallelism degree.
+        """
+        return self._tp_size
+
     @property
     def model_dtype(self):
         logger.warning(
@@ -311,10 +320,15 @@ def __init__(

         self.generation_config = GenerationConfig.from_model_config(self.config)
         try:
-            self.model_cls = get_class_from_dynamic_module(
-                self.config.auto_map["AutoModelForCausalLM"], model_save_dir
-            )
-        except AttributeError:
+            # Use model_save_dir if available, otherwise use config's name_or_path
+            pretrained_model_name_or_path = model_save_dir or getattr(self.config, "_name_or_path", None)
+            if pretrained_model_name_or_path is not None and hasattr(self.config, "auto_map"):
+                self.model_cls = get_class_from_dynamic_module(
+                    self.config.auto_map["AutoModelForCausalLM"], pretrained_model_name_or_path
+                )
+            else:
+                self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
+        except (AttributeError, KeyError):
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

         if hasattr(self.model_cls, "_convert_to_standard_cache"):
@@ -497,10 +511,15 @@ def __init__(

         self.generation_config = GenerationConfig.from_model_config(self.config)
         try:
-            self.model_cls = get_class_from_dynamic_module(
-                self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
-            )
-        except AttributeError:
+            # Use model_save_dir if available, otherwise use config's name_or_path
+            pretrained_model_name_or_path = model_save_dir or getattr(self.config, "_name_or_path", None)
+            if pretrained_model_name_or_path is not None and hasattr(self.config, "auto_map"):
+                self.model_cls = get_class_from_dynamic_module(
+                    self.config.auto_map["AutoModelForSeq2SeqLM"], pretrained_model_name_or_path
+                )
+            else:
+                self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
+        except (AttributeError, KeyError):
             self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)

         if hasattr(self.model_cls, "_convert_to_standard_cache"):
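
Below is a minimal usage sketch of what the changes above enable, assuming the public optimum.intel.IPEXModelForCausalLM entry point; the checkpoint path is a placeholder, and tp_size simply surfaces the wrapped model's _tp_size attribute (it stays None when the model was not loaded with tensor parallelism):

from optimum.intel import IPEXModelForCausalLM

# from_pretrained now forwards the model id as model_save_dir, so checkpoints that
# rely on auto_map / remote code can be re-resolved from their original location.
model = IPEXModelForCausalLM.from_pretrained("path/to/checkpoint")  # placeholder path

# New read-only property: the tensor parallelism degree read from the underlying
# model's _tp_size attribute (None when TP is not in use).
print(model.tp_size)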