
Commit 966a318

authored by jiqing-feng
Patch (#1200)
* upgrade transformers to 4.49 for patching models
* update setup
* disable linear fusion when using compile
* use max-autotune
* fix compile param
* fix tests
* disable max-autotune
* make compile as a static method
* fix opt test

Signed-off-by: jiqing-feng <[email protected]>
1 parent d4bd848 commit 966a318

File tree (3 files changed, +45 −36 lines):

  optimum/exporters/ipex/modeling_utils.py
  optimum/intel/ipex/modeling_base.py
  tests/ipex/test_modeling.py

optimum/exporters/ipex/modeling_utils.py

Lines changed: 15 additions & 20 deletions
@@ -229,7 +229,7 @@ def _llama_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -357,7 +357,7 @@ def _falcon_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -499,7 +499,7 @@ def _gpt2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -635,7 +635,7 @@ def _qwen2_model_forward(
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
     seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
     query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
-    max_input_lens = input_lens.max().item()
+    max_input_lens = input_lens.max()

     if past_key_values_length == 0 and past_key_values is not None:
         # first token, remove the padding from hidden_states, varlen do not accept attention mask
@@ -754,11 +754,11 @@ def attention_interface(
         if past_key_value is None:
             n_rep = query.shape[1] // key.shape[1]
             attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                query.reshape(input_lens.shape[0], input_lens.max(), -1, query.shape[-1]).transpose(1, 2),
+                key.reshape(input_lens.shape[0], input_lens.max(), -1, key.shape[-1])
                 .transpose(1, 2)
                 .repeat_interleave(n_rep, 1),
-                value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                value.reshape(input_lens.shape[0], input_lens.max(), -1, value.shape[-1])
                 .transpose(1, 2)
                 .repeat_interleave(n_rep, 1),
                 attn_mask=attention_mask,
@@ -885,13 +885,11 @@ def __init__(self, module, device, config) -> None:
         self.q_slice = self.q_proj.weight.shape[0]
         self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
         self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+        if not config.compile and module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
+            if self.module_device.type == "cpu":
                 self.mha_linear_add = LinearAdd(module.o_proj)
-
         elif self.module_device.type == "xpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+            self.mha_linear_add = XPULinearAdd(module.o_proj)

     def qkv_gemm(self, hidden_states):
         if hasattr(self, "concat_qkv"):
@@ -935,7 +933,7 @@ class _IPEXGPT2Attention(_IPEXAttention):
     def __init__(self, module, device, config) -> None:
         super().__init__(module, device, config)
         _setattr_from_module(self, module)
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1])
             self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t())
             self.c_attn_linear.bias = self.c_attn.bias
@@ -979,7 +977,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
                 if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
@@ -1012,7 +1010,7 @@ def __init__(self, module, device, config) -> None:
         _setattr_from_module(self, module)
         self.config = config
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense_h_to_4h)
@@ -1052,7 +1050,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device

-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1])
             self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t())
             self.c_fc_linear.bias = self.c_fc.bias
@@ -1061,11 +1059,8 @@ def __init__(self, module, device, config) -> None:
             self.c_proj_linear.bias = self.c_proj.bias
             if self.module_device.type == "cpu":
                 self.linear_new_gelu = LinearNewGelu(self.c_fc_linear)
-
-            if self.module_device.type == "cpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = LinearAdd(self.c_proj_linear)
-
             elif self.module_device.type == "xpu":
                 if self.c_proj_linear not in ["LinearAllreduce"]:
                     self.linear_add = XPULinearAdd(self.c_proj_linear)
@@ -1237,7 +1232,7 @@ def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
         self.module_device = device
-        if getattr(config, "quantization_config", None) is None:
+        if not config.compile and getattr(config, "quantization_config", None) is None:
             if self.module_device.type == "cpu":
                 self.linear_gelu = LinearGelu(module.dense)
             elif self.module_device.type == "xpu":
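
Note (illustration, not part of the commit): one repeated change in this file is dropping .item() from input_lens.max(). A minimal sketch of why this matters, assuming the usual torch.compile behavior: .item() materializes a Python int and typically forces a graph break under torch.compile, while a 0-dim integer tensor can be passed straight to reshape(). The shapes and names below are illustrative only, not taken from the repo.

import torch

input_lens = torch.tensor([3, 5, 4], dtype=torch.int32)  # per-sequence lengths
max_input_lens = input_lens.max()                        # stays a 0-dim tensor, compile-friendly
query = torch.randn(input_lens.shape[0] * 5, 8, 64)      # toy hidden states padded to the max length (5)

# reshape() converts the 0-dim tensor via __index__, so no .item() is needed
q = query.reshape(input_lens.shape[0], max_input_lens, -1, query.shape[-1])
print(q.shape)  # torch.Size([3, 5, 8, 64])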

optimum/intel/ipex/modeling_base.py

Lines changed: 24 additions & 15 deletions
@@ -146,7 +146,7 @@ def __init__(
         self.use_cache = kwargs.get("use_cache", False)
         self.model_save_dir = model_save_dir
         self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache)
-        self.compiled = False
+        self.model.config.compile = self.can_compile(self.model, self.use_cache)

         self.input_names = set(inspect.signature(model.forward).parameters)

@@ -158,9 +158,10 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)

-        self.maybe_apply_torch_compile()
+        if getattr(self.model.config, "compile", False):
+            self.apply_torch_compile()

-        if warmup and not self.compiled:
+        if warmup and not getattr(self.model.config, "compile", False):
             self._init_warmup()

     @classmethod
@@ -231,24 +232,28 @@ def to(self, device: Union[torch.device, str]):
     def can_generate(self):
         return isinstance(self, GenerationMixin)

-    def maybe_apply_torch_compile(self):
+    @staticmethod
+    def can_compile(model, use_cache):
         if (
-            self.model.device.type != "cpu"
-            or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
+            model.device.type != "cpu"
+            or model.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
             or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
-            or getattr(self.config, "quantization_config", None)
+            or getattr(model.config, "quantization_config", None)
         ):
-            return
-        if self.use_cache and not self._supports_static_cache:
-            return
+            return False
+        if use_cache and not model._supports_static_cache:
+            return False
+
+        return True
+
+    def apply_torch_compile(self):
         from torch._inductor import config as inductor_config

         # System level optimization
         inductor_config.cpp_wrapper = True
         os.environ["TORCHINDUCTOR_FREEZING"] = "1"
         logger.info("Enable torch.compile optimization")
         self.model.forward = torch.compile(self.model.forward)
-        self.compiled = True

     def _init_warmup(self):
         inputs = prepare_jit_inputs(self.model, self.export_feature, False)
@@ -328,7 +333,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache

-        if warmup and not self.compiled:
+        if warmup and not getattr(self.model.config, "compile", False):
             self._init_warmup()

     @torch.no_grad()
@@ -348,7 +353,11 @@ def _prepare_generation_config(
         kwargs["use_cache"] = self.use_cache
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
-        if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
+        if (
+            getattr(self.model.config, "compile", False)
+            and generation_config.cache_implementation != "ipex_paged"
+            and self._supports_static_cache
+        ):
             # Use static cache for torch compile
             generation_config.cache_implementation = "static"
         if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS:
@@ -459,7 +468,7 @@ def __init__(
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache

-        if warmup and not self.compiled:
+        if warmup and not getattr(self.model.config, "compile", False):
             self._init_warmup()

     @torch.no_grad()
@@ -476,7 +485,7 @@ def _prepare_generation_config(
     ) -> Tuple[GenerationConfig, Dict]:
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         # Use static cache for torch.compile
-        if self.compiled:
+        if getattr(self.model.config, "compile", False):
             generation_config.cache_implementation = "static"

         return generation_config, model_kwargs
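
Note (illustration, not part of the commit): this file replaces the instance flag self.compiled with a config-level flag. Roughly, the decision is computed once by a static can_compile() check, stored as model.config.compile so the patched modeling code in modeling_utils.py can read it, and then used to gate both torch.compile and the eager warmup. Below is a condensed, standalone sketch of that flow with simplified checks and a toy class name; it is not the actual IPEXModel implementation.

import torch


class ToyIPEXModel:
    def __init__(self, model, use_cache=True, warmup=True):
        self.model = model
        self.use_cache = use_cache
        # Decide once, and expose the decision on the config so module-level
        # patching code (e.g. the fusion gating in modeling_utils.py) sees it.
        self.model.config.compile = self.can_compile(model, use_cache)
        if getattr(self.model.config, "compile", False):
            self.apply_torch_compile()
        # Warmup is skipped when compiling, mirroring the diff above.
        if warmup and not getattr(self.model.config, "compile", False):
            self._init_warmup()

    @staticmethod
    def can_compile(model, use_cache) -> bool:
        # Simplified version of the checks in the diff: CPU only, no
        # quantization config, and static-cache support when use_cache is True.
        if model.device.type != "cpu" or getattr(model.config, "quantization_config", None):
            return False
        if use_cache and not getattr(model, "_supports_static_cache", False):
            return False
        return True

    def apply_torch_compile(self):
        self.model.forward = torch.compile(self.model.forward)

    def _init_warmup(self):
        pass  # an eager warmup pass would go here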

tests/ipex/test_modeling.py

Lines changed: 6 additions & 1 deletion
@@ -374,7 +374,12 @@ def test_ipex_beam_search(self, test_name, model_arch, use_cache):
             model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE
         )
         # It will be removed when torch 2.6 released
-        if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"):
+        if (
+            model_arch == "opt"
+            and not use_cache
+            and getattr(model.config, "compile", False)
+            and is_torch_version("<", "2.6.0")
+        ):
             return
         if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES:
             self.assertTrue(model.add_patch)
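
Note (illustration, not part of the commit): since model.compiled no longer exists, downstream checks such as this test read the flag defensively from the config. Assuming model is a loaded IPEX model instance:

is_compiled = getattr(model.config, "compile", False)  # False when the flag was never set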
