@@ -630,6 +630,126 @@ def _qwen2_model_forward(
     return output if return_dict else output.to_tuple()
 
 
+# Adapted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral/modeling_mistral.py#L459
+def _mistral_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    batch_size, seq_length = inputs_embeds.shape[:2]
+    device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+    past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # create position embeddings to be shared across the decoder layers
+    position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+    # part of the code that was modified below
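+    # per-sequence token counts and their cumulative offsets: the IPEX attention layers consume
+    # these instead of a padded attention mask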
+    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+    query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int()
+    max_input_lens = input_lens.max()
+    cos = position_embeddings[0]
+    sin = position_embeddings[1]
+    if past_key_values_length == 0 and past_key_values is not None:
+        # first token: remove the padding from hidden_states, since varlen attention does not accept an attention mask
+        hidden_states_copy = hidden_states
+        index = attention_mask.view(-1) != 0
+        hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
+        cos = (cos.reshape(-1, cos.shape[-1]))[index]
+        sin = (sin.reshape(-1, sin.shape[-1]))[index]
+        position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
+    else:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # TODO: remove this workaround after IPEX 2.7
+        if device.type == "xpu":
+            cos = cos.reshape(-1, cos.shape[-1])
+            sin = sin.reshape(-1, sin.shape[-1])
+            position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
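+    # with no KV cache, fall back to the standard 4D causal mask built by _update_causal_mask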
+    if past_key_values is None:
+        attention_mask = causal_mask
+    # part of the code that was modified above
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+
+    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            input_lens=input_lens,
+            max_input_lens=max_input_lens,
+            seq_len_tensor=seq_len_tensor,
+            query_len_tensor=query_len_tensor,
+            **kwargs,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
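+    # if padding was stripped for the varlen path, scatter the outputs back into the padded buffer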
+    if hidden_states.shape[0] != batch_size * seq_length:
+        (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states
+        hidden_states = hidden_states_copy
+    hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    output = BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=past_key_values if use_cache else None,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+    return output if return_dict else output.to_tuple()
+
+
 class _IPEXAttention(nn.Module):
     def __init__(self, module, device, config) -> None:
         super().__init__()
@@ -904,7 +1024,8 @@ def __init__(self, module, device, config) -> None:
         # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
         if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
             self.mlp_linear_add = LinearAdd(module.down_proj)
-        self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
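+        # Linear2SiluMul fuses the gate/up projections with SiLU, so only create it when the activation is SiLU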
+        if isinstance(self.act_fn, nn.SiLU):
+            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
@@ -1136,6 +1257,11 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
 
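+# Mistral's decoder layer matches Llama's structure, so the IPEX Llama decoder layer wrapper is reused unchanged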
+class _IPEXMistralDecoderLayer(_IPEXLlamaDecoderLayer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
     def __init__(self, module, device, config):