Commit f027f08

Add Mistral modeling optimization support for ipex (#1269)
* add mistral type model support
* fix typo
* Update optimum/exporters/ipex/modeling_utils.py
* add comments
* add `--fix` to ruff check
* fix bug

Signed-off-by: Liu, Kaixuan <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
1 parent bfe8beb commit f027f08

File tree

3 files changed: +146 -3 lines changed

optimum/exporters/ipex/model_patcher.py

Lines changed: 17 additions & 0 deletions
@@ -20,6 +20,7 @@
     LlamaModel,
     LlamaRMSNorm,
 )
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel, MistralRMSNorm
 from transformers.models.qwen2.modeling_qwen2 import (
     Qwen2DecoderLayer,
     Qwen2Model,
@@ -39,8 +40,10 @@
     _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
+    _IPEXMistralDecoderLayer,
     _IPEXQwen2DecoderLayer,
     _llama_model_forward,
+    _mistral_model_forward,
     _qwen2_model_forward,
 )

@@ -132,6 +135,18 @@ def _patch_qwen2_model(model):
     return model


+def _patch_mistral_model(model):
+    """
+    Patch mistral model:
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (Linear + Add)
+    """
+    convert_functions(model, MistralModel, "forward", _mistral_model_forward)
+    convert_functions(model, MistralRMSNorm, "forward", _ipex_rms_layer_norm_forward)
+    convert_class(model, MistralDecoderLayer, _IPEXMistralDecoderLayer, model.device, model.config)
+    return model
+
+
 def _patch_bert_model(model):
     """
     Patch bert model:
@@ -167,6 +182,8 @@ def _patch_model(model):
         model = _patch_gpt2_model(model)
     elif model.config.model_type == "qwen2":
         model = _patch_qwen2_model(model)
+    elif model.config.model_type == "mistral":
+        model = _patch_mistral_model(model)
     elif model.config.model_type == "bert":
         model = _patch_bert_model(model)
     elif model.config.model_type == "vit":
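
For context, this dispatch runs when a mistral checkpoint is loaded through the IPEX model classes in optimum-intel; the user-facing API does not change. A minimal usage sketch (the checkpoint name, dtype, and prompt below are illustrative and not part of this commit):

import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Loading a mistral-type checkpoint now routes through _patch_mistral_model, which
# installs _mistral_model_forward (IPEX rope + paged cache) and the fused decoder layer.
model_id = "mistralai/Mistral-7B-v0.1"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Paged attention lets the runtime", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))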

optimum/exporters/ipex/modeling_utils.py

Lines changed: 127 additions & 1 deletion
@@ -630,6 +630,126 @@ def _qwen2_model_forward(
     return output if return_dict else output.to_tuple()


+# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral/modeling_mistral.py#L459
+def _mistral_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    # part of the code that was modified below
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max()
+    cos = position_embeddings[0]
+    sin = position_embeddings[1]
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token, remove the padding from hidden_states, varlen do not accept attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # TODO: remove this WA after IPEX 2.7
+        if device.type == "xpu":
+            cos = cos.reshape(-1, cos.shape[-1])
+            sin = sin.reshape(-1, sin.shape[-1])
+            position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    if past_key_values is None:
+        attention_mask = causal_mask
+    # part of the code that was modified above
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, device, config) -> None:
         super().__init__()
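
The non-obvious part of the forward above is the varlen packing: on the first token the padded (batch, seq, hidden) activations are flattened, padding positions are dropped with a boolean index built from the attention mask, and after the decoder stack the results are scattered back into the padded buffer. A standalone sketch of just that indexing pattern, independent of IPEX (the shapes and values here are made up for illustration):

import torch

batch_size, seq_len, hidden = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden)
# 1 = real token, 0 = padding (second sequence is left-padded)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

# Flatten and keep only non-padding rows, as in the first-token branch above.
index = attention_mask.view(-1) != 0
packed = hidden_states.view(-1, hidden)[index]  # shape: (num_real_tokens, hidden)

packed = packed * 2.0  # stand-in for the decoder layers + final norm

# Scatter the processed rows back into the padded buffer, as done after self.norm().
restored = hidden_states.clone()
restored.view(-1, hidden)[index] = packed
restored = restored.view(batch_size, -1, hidden)
print(packed.shape, restored.shape)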
@@ -904,7 +1024,8 @@ def __init__(self, module, device, config) -> None:
         # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
         if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
             self.mlp_linear_add = LinearAdd(module.down_proj)
-        self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+        if isinstance(self.act_fn, nn.SiLU):
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)

     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
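
The new isinstance guard matters because Linear2SiluMul is a SiLU-specific fusion: it is only a valid replacement when the MLP really computes silu(gate_proj(x)) * up_proj(x), which holds for the Llama/Mistral defaults but not for checkpoints configured with another hidden activation. A reference sketch of the unfused computation the fusion stands in for (module names follow the transformers MLP convention; this is not the IPEX kernel itself):

import torch
import torch.nn as nn

def unfused_gate_up(x: torch.Tensor, gate_proj: nn.Linear, up_proj: nn.Linear) -> torch.Tensor:
    # Eager equivalent of what Linear2SiluMul(gate_proj, up_proj) fuses:
    # two linear projections, SiLU on the gate branch, elementwise multiply.
    return nn.functional.silu(gate_proj(x)) * up_proj(x)

# With a different act_fn (e.g. GELU) this identity no longer holds, so the patched
# __init__ above keeps the unfused path by skipping the fusion.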
@@ -1136,6 +1257,11 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)


+class _IPEXMistralDecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, device, config):

optimum/intel/ipex/modeling_base.py

Lines changed: 2 additions & 2 deletions
@@ -61,14 +61,14 @@
 logger = logging.getLogger(__name__)


-_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2")
+_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2", "mistral")
 _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation")
 _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0"
 # Page attention model cannot use torch.compile for now.
 if is_torch_version("<", "2.6"):
     _COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2")
 else:
-    _COMPILE_NOT_READY_MODEL_TYPES = ("llama", "falcon", "gpt2", "qwen2")
+    _COMPILE_NOT_READY_MODEL_TYPES = ("llama", "falcon", "gpt2", "qwen2", "mistral")


 try:
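
These two tuples are plain allow-lists: _IPEX_SUPPORT_MODEL_TYPES gates which architectures get the IPEX graph patching at all, while _COMPILE_NOT_READY_MODEL_TYPES keeps torch.compile disabled for the paged-attention models. A simplified, hypothetical sketch of how such membership checks are typically consumed (this is not the actual modeling_base.py control flow):

def should_patch(model_type: str) -> bool:
    # Only architectures with an IPEX-optimized forward are rewritten.
    return model_type in _IPEX_SUPPORT_MODEL_TYPES

def can_use_torch_compile(model_type: str) -> bool:
    # Paged-attention models (now including "mistral") stay on the eager path for now.
    return model_type not in _COMPILE_NOT_READY_MODEL_TYPES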
