
Commit 5db1ac7

Add support for assisted decoding in ipex 2.4 (#823)

* support assisted decoding in ipex 2.5
* Update optimum/intel/ipex/modeling_base.py

  Co-authored-by: Ella Charlaix <[email protected]>

* fix failing tests
* fix style
* ipex onnx config
* patch before generate and un-patch after generate
* only patch functions in assisted decoding
* try and cache the generation result and do un-patch
* raise error
* fix style
* ipex 2.4 supports assisted decoding
* fix inputs
* fix generate
* enable assisted decoding tests
* more tests on assisted decoding
* fix config name
* unpatch the target model's generation
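In practice, the feature this commit enables can be exercised as in the sketch below. This is a hypothetical usage example, not part of the diff: the model ID and generation settings are placeholders, and it assumes optimum-intel with this commit plus ipex >= 2.4 installed.

# Hypothetical usage sketch for assisted decoding with an IPEX-patched model.
# "gpt2" and max_new_tokens are illustrative placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True runs the IPEX patching / jit-trace path touched by this commit.
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
assistant = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("This is a sample", return_tensors="pt")
# Passing assistant_model triggers assisted decoding inside generate(), which
# now patches/unpatches _crop_past_key_values around the call (see below).
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=4)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))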
1 parent 8a015a6 commit 5db1ac7

3 files changed (+173, -17 lines)
optimum/exporters/ipex/model_config.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from optimum.exporters.onnx.model_configs import (
+    FalconOnnxConfig,
+    GPT2OnnxConfig,
+    LlamaOnnxConfig,
+)
+from optimum.utils import DEFAULT_DUMMY_SHAPES
+from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator
+from optimum.utils.normalized_config import NormalizedTextConfig
+
+
+DEFAULT_DUMMY_SHAPES["batch_size"] = 1
+
+
+class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1)
+        self.max_position_embeddings = normalized_config.max_position_embeddings
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape_init = (1, self.sequence_length, self.sequence_length, 1)
+        shape_beam_idx_tmp = (self.max_position_embeddings, self.batch_size)
+        shape_kv = (
+            self.max_position_embeddings,
+            self.batch_size,
+            self.num_key_value_heads,
+            self.hidden_size // self.num_attention_heads,
+        )
+        return [
+            (
+                self.random_int_tensor(shape_init, max_value=1, framework=framework).contiguous(),
+                self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
+                self.random_float_tensor(shape_kv, framework=framework, dtype=float_dtype).contiguous(),
+                self.random_int_tensor(shape_beam_idx_tmp, max_value=1, framework=framework).contiguous(),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+class IPEXDummyTextInputGenerator(DummyTextInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        **kwargs,
+    ):
+        super().__init__(task, normalized_config, batch_size, **kwargs)
+
+
+class LlamaIPEXConfig(LlamaOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator
+
+
+class FalconIPEXConfig(FalconOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator
+
+
+class GPT2IPEXConfig(GPT2OnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator
+
+
+ipex_onnx_config = {"llama": LlamaIPEXConfig, "falcon": FalconIPEXConfig, "gpt2": GPT2IPEXConfig}
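For orientation, each element the generator above returns is one layer of IPEX's indirect-access KV (IAKV) cache: a 4-tuple of (init, key, value, beam_idx) rather than the usual (key, value) pair, with the key/value buffers pre-allocated to max_position_embeddings. The sketch below rebuilds one layer with plain torch to make the shapes concrete; the config values (2048 positions, 4 KV heads, head size 64) are illustrative, not taken from the diff.

import torch

# Illustrative config values (not from the diff):
batch_size, sequence_length = 1, 16        # DEFAULT_DUMMY_SHAPES, batch forced to 1 above
max_position_embeddings = 2048
num_key_value_heads, head_dim = 4, 64      # head_dim = hidden_size // num_attention_heads

# One layer of the IAKV-style cache, mirroring the shapes built in generate():
layer = (
    torch.zeros(1, sequence_length, sequence_length, 1, dtype=torch.int64).contiguous(),          # init
    torch.rand(max_position_embeddings, batch_size, num_key_value_heads, head_dim).contiguous(),  # key
    torch.rand(max_position_embeddings, batch_size, num_key_value_heads, head_dim).contiguous(),  # value
    torch.zeros(max_position_embeddings, batch_size, dtype=torch.int64).contiguous(),             # beam_idx
)

Because the key/value buffers are already full-length, cropping this cache only needs to slice the first tensor, which is what the `_ipex_crop_past_key_values` helper added below does.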

optimum/intel/ipex/modeling_base.py

Lines changed: 69 additions & 14 deletions
@@ -23,10 +23,10 @@

 import intel_extension_for_pytorch as ipex
 import torch
+import transformers
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
-from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -43,20 +43,24 @@
     is_torch_xpu_available,
 )
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.generation.candidate_generator import _crop_past_key_values
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 from transformers.utils import WEIGHTS_NAME

 from optimum.exporters import TasksManager
+from optimum.exporters.tasks import make_backend_config_constructor_for_task
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager

+from ...exporters.ipex.model_config import ipex_onnx_config
 from ...exporters.ipex.model_patcher import (
     _IPEX_EXPORTED_GENERATION_TASKS,
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _patch_model,
 )
-from ..generation.modeling import prepare_jit_inputs
+from ..generation.modeling import get_float_type
+from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device

@@ -86,10 +90,35 @@ def _is_patched_with_ipex(model, task):


 def _prepare_inputs_for_ipex_model(model, task, use_cache):
-    if task in _IPEX_EXPORTED_GENERATION_TASKS and _is_patched_with_ipex(model, task):
-        return get_dummy_input(model, return_dict=True)
+    task = _TASK_ALIASES.get(task, task)
+    signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
+    if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config:
+        onnx_config_class = make_backend_config_constructor_for_task(
+            ipex_onnx_config[model.config.model_type], task=task
+        )
+    else:
+        onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
+    float_dtype = get_float_type(model.dtype)
+    if "text-generation" in task:
+        onnx_config = onnx_config_class(
+            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
+        )
     else:
-        return prepare_jit_inputs(model, task, use_cache)
+        onnx_config = onnx_config_class(model.config)
+
+    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+
+    # Check that the attention_mask length matches input_ids + past_key_values
+    if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache:
+        past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
+        input_len = dummy_inputs["input_ids"].shape[-1]
+        attention_len = dummy_inputs["attention_mask"].shape[-1]
+        if attention_len != input_len + past_len:
+            dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
+                dummy_inputs["input_ids"].dtype
+            )
+
+    return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}


 def ipex_jit_trace(model, task, use_cache):
@@ -103,11 +132,7 @@ def ipex_jit_trace(model, task, use_cache):
     sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

     model.config.return_dict = False
-
-    if "past_key_values" in sample_inputs:
-        model.config.use_cache = use_cache
-        if not use_cache:
-            sample_inputs.pop("past_key_values")
+    model.config.use_cache = use_cache

     # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
     # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
@@ -372,7 +397,7 @@ def _init_warmup(self):
         # TODO : add warmup for IPEX exported model
         if not self._is_ipex_exported:
             use_cache = "past_key_values" in self.input_names
-            dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+            dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache)
             if self._device.type != "cpu":
                 dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device)
             for _ in range(2):
@@ -652,11 +677,28 @@ def _prepare_generation_config(
         return generation_config, model_kwargs

     def generate(self, *args, **kwargs):
-        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+        if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None):
             raise ValueError(
-                f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+                f"Assisted decoding is not supported for patched models if ipex < 2.4, supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
-        return super().generate(*args, **kwargs)
+        # Patch functions to support the IAKV cache
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
+        elif self._is_ipex_exported:
+            transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
+
+        try:
+            result = super().generate(*args, **kwargs)
+        except Exception as e:
+            transformers.generation.utils._crop_past_key_values = _crop_past_key_values
+            transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
+            raise e
+
+        if self._is_ipex_exported and kwargs.get("assistant_model", None):
+            transformers.generation.utils._crop_past_key_values = _crop_past_key_values
+            transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
+
+        return result


 def _ipex_prepare_inputs_for_generation(
@@ -736,3 +778,16 @@ def _ipex_reorder_cache(
         tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
         for layer_past in past_key_values
     )
+
+
+def _ipex_crop_past_key_values(model, past_key_values, max_length):
+    if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"):
+        new_past_key_values = []
+        for i in range(len(past_key_values)):
+            pkv = []
+            pkv.append(past_key_values[i][0][:, :max_length, :max_length, :])
+            pkv += [past_key_values[i][_] for _ in range(1, 4)]
+            new_past_key_values.append(tuple(pkv))
+        new_past_key_values = tuple(new_past_key_values)
+        return new_past_key_values
+    return _crop_past_key_values(model, past_key_values, max_length)
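The new generate() above works by temporarily rebinding the module-level `_crop_past_key_values` that transformers' assisted decoding calls, then restoring the stock helper once generation finishes or fails. Below is a minimal sketch of that pattern, using a try/finally in place of the commit's except-branch-plus-post-success restore (the effect is the same); `_my_crop` and `generate_with_patched_crop` are stand-in names, not part of the diff.

import transformers
from transformers.generation.candidate_generator import _crop_past_key_values


def _my_crop(model, past_key_values, max_length):
    # An IAKV-aware implementation would go here; fall back to the stock helper.
    return _crop_past_key_values(model, past_key_values, max_length)


def generate_with_patched_crop(model, *args, **kwargs):
    # Rebind both call sites before generation starts, as generate() does above.
    transformers.generation.utils._crop_past_key_values = _my_crop
    transformers.generation.candidate_generator._crop_past_key_values = _my_crop
    try:
        return model.generate(*args, **kwargs)
    finally:
        # Always restore the originals so other (non-IPEX) models in the same
        # process keep the default cropping behavior.
        transformers.generation.utils._crop_past_key_values = _crop_past_key_values
        transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values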

tests/ipex/test_modeling.py

Lines changed: 6 additions & 3 deletions
@@ -281,11 +281,10 @@ def test_pipeline(self, model_arch):
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))

-    # High optimized model llama is not supported assisted decoding for now.
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_assisted_decoding(self, model_arch):
-        # Patched models are not support assisted decoding for now.
-        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
+        # Patched models do not support assisted decoding if ipex < 2.4.
+        if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"):
             return
         model_id = MODEL_NAMES[model_arch]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -296,11 +295,15 @@ def test_assisted_decoding(self, model_arch):
         ipex_output_assisted = ipex_model.generate(
             **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4
         )
+        ipex_output_assisted_2 = ipex_model.generate(
+            **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
+        )
         transformers_output = transformers_model.generate(**tokens, do_sample=False, max_new_tokens=4)
         transformers_output_assisted = transformers_model.generate(
             **tokens, do_sample=False, assistant_model=ipex_model, max_new_tokens=4
         )
         self.assertTrue(torch.equal(ipex_output, ipex_output_assisted))
+        self.assertTrue(torch.equal(ipex_output, ipex_output_assisted_2))
         self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))

     @parameterized.expand(
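The new assertions check that assisted decoding is output-equivalent to plain greedy decoding with the IPEX model used both as target and as assistant; locally they can be run with e.g. `pytest tests/ipex/test_modeling.py -k test_assisted_decoding` from a checkout. For reference, the sketch below shows what `_ipex_crop_past_key_values` does to one IAKV layer when the candidate generator rolls back rejected draft tokens; tensor shapes reuse the illustrative values from the earlier sketch.

import torch

# One IAKV layer (init, key, value, beam_idx); shapes are illustrative.
init = torch.zeros(1, 16, 16, 1, dtype=torch.int64)
key = torch.rand(2048, 1, 4, 64)
value = torch.rand(2048, 1, 4, 64)
beam_idx = torch.zeros(2048, 1, dtype=torch.int64)

max_length = 8
# Only the init tensor is sliced; key/value stay pre-allocated at full length.
cropped = (init[:, :max_length, :max_length, :], key, value, beam_idx)
assert cropped[0].shape == (1, 8, 8, 1)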
