 
 import intel_extension_for_pytorch as ipex
 import torch
+import transformers
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
-from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
 from transformers import (
     AutoConfig,
     AutoModel,
     is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.generation.candidate_generator import _crop_past_key_values
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 from transformers.utils import WEIGHTS_NAME
 
 from optimum.exporters import TasksManager
+from optimum.exporters.tasks import make_backend_config_constructor_for_task
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
+from ...exporters.ipex.model_config import ipex_onnx_config
 from ...exporters.ipex.model_patcher import (
     _IPEX_EXPORTED_GENERATION_TASKS,
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _patch_model,
 )
-from ..generation.modeling import prepare_jit_inputs
+from ..generation.modeling import get_float_type
+from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
 
@@ -86,10 +90,35 @@ def _is_patched_with_ipex(model, task):
 
 
 def _prepare_inputs_for_ipex_model(model, task, use_cache):
-    if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
-        return get_dummy_input(model, return_dict=True)
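+    # Resolve task aliases so the exporter lookup below sees the canonical task name.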
+    task = _TASK_ALIASES.get(task, task)
+    signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
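+    # Patched models have a dedicated IPEX ONNX config; anything else falls back
+    # to the generic ONNX exporter config for this model/task pair.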
+    if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
+        onnx_config_class = make_backend_config_constructor_for_task(
+            ipex_onnx_config[model.config.model_type], task=task
+        )
+    else:
+        onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+    float_dtype = get_float_type(model.dtype)
+    if "text-generation" in task:
+        onnx_config = onnx_config_class(
+            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
+        )
     else:
-        return prepare_jit_inputs(model, task, use_cache)
+        onnx_config = onnx_config_class(model.config)
+
+    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+
+    # The attention mask must cover both past and current tokens; rebuild it if the
+    # dummy generator produced one that only spans the new input.
+    if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache:
+        past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
+        input_len = dummy_inputs["input_ids"].shape[-1]
+        attention_len = dummy_inputs["attention_mask"].shape[-1]
+        if attention_len != input_len + past_len:
+            dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
+                dummy_inputs["input_ids"].dtype
+            )
+
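+    # Keep only the dummy inputs that the model's forward signature actually accepts.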
+    return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
 
 
 def ipex_jit_trace(model, task, use_cache):
@@ -103,11 +132,7 @@ def ipex_jit_trace(model, task, use_cache):
     sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)
 
     model.config.return_dict = False
-
-    if "past_key_values" in sample_inputs:
-        model.config.use_cache = use_cache
-        if not use_cache:
-            sample_inputs.pop("past_key_values")
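+    # use_cache is now reflected in the dummy inputs themselves, so only the config flag needs syncing.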
+    model.config.use_cache = use_cache
 
     # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
     # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
@@ -372,7 +397,7 @@ def _init_warmup(self):
         # TODO: add warmup for IPEX exported model
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
-            dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
             if self._device.type != "cpu":
                 dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
@@ -652,11 +677,28 @@ def _prepare_generation_config(
         return generation_config, model_kwargs
 
     def generate(self, *args, **kwargs):
-        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
             raise ValueError(
-                f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+                f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
-        return super().generate(*args, **kwargs)
+        # Patch _crop_past_key_values so cache cropping understands the IAKV layout of patched models
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
+        elif self._is_ipex_exported:
+            transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
+
+        try:
+            result = super().generate(*args, **kwargs)
+        except Exception as e:
+            transformers.generation.utils._crop_past_key_values = _crop_past_key_values
+            transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
+            raise e
+
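+        # Generation succeeded: undo the patch before returning.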
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            transformers.generation.utils._crop_past_key_values = _crop_past_key_values
+            transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
+
+        return result
 
 
 def _ipex_prepare_inputs_for_generation(
@@ -736,3 +778,16 @@ def _ipex_reorder_cache(
         tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
         for layer_past in past_key_values
     )
+
+
+def _ipex_crop_past_key_values(model, past_key_values, max_length):
+    if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
+        new_past_key_values = []
+        for i in range(len(past_key_values)):
+            pkv = []
+            pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
+            pkv += [past_key_values[i][_] for _ in range(1, 4)]
+            new_past_key_values.append(tuple(pkv))
+        new_past_key_values = tuple(new_past_key_values)
+        return new_past_key_values
+    return _crop_past_key_values(model, past_key_values, max_length)