Skip to content

Commit 72e08c4

Browse files
authored
Fix some minor bugs after upgrading transformers to 4.55 for IPEX (#1485)
* fix bug for 'IPEXModelForCausalLM' object has no attribute '_can_compile_fullgraph'
  Signed-off-by: Liu, Kaixuan <[email protected]>
* update
  Signed-off-by: Liu, Kaixuan <[email protected]>
* fix bug for model load
  Signed-off-by: Liu, Kaixuan <[email protected]>
* pass model_save_dir with value 'model_id'
  Signed-off-by: Liu, Kaixuan <[email protected]>

---------

Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent e40ca5d commit 72e08c4

File tree

1 file changed: +28 additions, -9 deletions

optimum/intel/ipex/modeling_base.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def __init__(
143143
self._supports_sdpa = getattr(model, "_supports_sdpa", None)
144144
self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None)
145145
self._supports_static_cache = getattr(model, "_supports_static_cache", None)
146+
self._can_compile_fullgraph = getattr(model, "_can_compile_fullgraph", False)
147+
self._tp_size = getattr(model, "_tp_size", None)
146148
self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32
147149
self.use_cache = kwargs.get("use_cache", False)
148150
self.model_save_dir = model_save_dir
@@ -182,7 +184,7 @@ def from_pretrained(
182184
model = cls.auto_model_class.from_pretrained(model_id, **kwargs)
183185
if getattr(model.config, "torchscript", False):
184186
raise ValueError("IPEXModel is no longer support torchscript models.")
185-
return cls(model, config=kwargs.pop("config", model.config), **kwargs)
187+
return cls(model, config=kwargs.pop("config", model.config), model_save_dir=model_id, **kwargs)
186188

187189
def _save_pretrained(self, save_directory: Union[str, Path]):
188190
self.model.save_pretrained(save_directory, safe_serialization=False)
@@ -207,6 +209,13 @@ def device(self) -> torch.device:
207209
def dtype(self) -> torch.dtype:
208210
return self._dtype
209211

212+
@property
213+
def tp_size(self):
214+
"""
215+
Returns the model's tensor parallelism degree.
216+
"""
217+
return self._tp_size
218+
210219
@property
211220
def model_dtype(self):
212221
logger.warning(
@@ -311,10 +320,15 @@ def __init__(
311320

312321
self.generation_config = GenerationConfig.from_model_config(self.config)
313322
try:
314-
self.model_cls = get_class_from_dynamic_module(
315-
self.config.auto_map["AutoModelForCausalLM"], model_save_dir
316-
)
317-
except AttributeError:
323+
# Use model_save_dir if available, otherwise use config's name_or_path
324+
pretrained_model_name_or_path = model_save_dir or getattr(self.config, "_name_or_path", None)
325+
if pretrained_model_name_or_path is not None and hasattr(self.config, "auto_map"):
326+
self.model_cls = get_class_from_dynamic_module(
327+
self.config.auto_map["AutoModelForCausalLM"], pretrained_model_name_or_path
328+
)
329+
else:
330+
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
331+
except (AttributeError, KeyError):
318332
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
319333

320334
if hasattr(self.model_cls, "_convert_to_standard_cache"):
@@ -497,10 +511,15 @@ def __init__(
497511

498512
self.generation_config = GenerationConfig.from_model_config(self.config)
499513
try:
500-
self.model_cls = get_class_from_dynamic_module(
501-
self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir
502-
)
503-
except AttributeError:
514+
# Use model_save_dir if available, otherwise use config's name_or_path
515+
pretrained_model_name_or_path = model_save_dir or getattr(self.config, "_name_or_path", None)
516+
if pretrained_model_name_or_path is not None and hasattr(self.config, "auto_map"):
517+
self.model_cls = get_class_from_dynamic_module(
518+
self.config.auto_map["AutoModelForSeq2SeqLM"], pretrained_model_name_or_path
519+
)
520+
else:
521+
self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
522+
except (AttributeError, KeyError):
504523
self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping)
505524

506525
if hasattr(self.model_cls, "_convert_to_standard_cache"):

0 commit comments

Comments (0)