diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 0ffb612d6b..e4df1733cd 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -1608,15 +1608,54 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value
 
 
+# @torch.jit.script
+def select_ext_factor(
+    seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
+):
+    return torch.where(
+        seq_len <= max_pos_embeddings, short_factor, long_factor
+    )  # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings)
+
+
+def long_rope(self, x, position_ids, seq_len=None):
+    seq_len = torch.max(position_ids) + 1
+    original_max_position_embeddings = (
+        self.original_max_position_embeddings
+        if hasattr(self, "original_max_position_embeddings")
+        else self.config.original_max_position_embeddings
+    )
+    max_position_embeddings = (
+        self.max_position_embeddings
+        if hasattr(self, "max_position_embeddings")
+        else self.config.max_position_embeddings
+    )
+    inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos, sin
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
         # currently, long RoPE can not be traced for long context support, disable it for avoid potential accuracy issues
-        if self._model.config.max_position_embeddings != getattr(
-            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
-        ):
-            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
 
         if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
@@ -1644,6 +1683,23 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
 
+        if (
+            hasattr(self._model.model, "rotary_emb")
+            and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
+        ):
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config,
+                torch.device("cpu"),
+                seq_len=self._model.config.original_max_position_embeddings + 1,
+            )
+            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
+            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
+            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
+        elif self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_forward"):
@@ -1653,6 +1709,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
+            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
 
 
 # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index cafa678288..e71f860fcd 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -362,10 +362,10 @@ def _export(
             variant=variant,
         )
 
-        if config.model_type == "phi3" and config.max_position_embeddings != getattr(
-            config, "original_max_position_embeddings", config.max_position_embeddings
-        ):
-            config.max_position_embeddings = config.original_max_position_embeddings
+        # if config.model_type == "phi3" and config.max_position_embeddings != getattr(
+        #     config, "original_max_position_embeddings", config.max_position_embeddings
+        # ):
+        #     config.max_position_embeddings = config.original_max_position_embeddings
 
         return cls._from_pretrained(
             model_id=save_dir_path,
@@ -870,6 +870,8 @@ def _from_pretrained(
             init_cls = OVBloomForCausalLM
         elif model_type == "gpt_bigcode":
             init_cls = OVGPTBigCodeForCausalLM
+        elif model_type == "phi3":
+            init_cls = OVPhi3ForCausalLM
         elif model_type in SSM_MODELS:
             init_cls = OVMambaForCausalLM
         else:
@@ -943,6 +945,47 @@ def _from_pretrained(
         return causal_model
 
 
+class OVPhi3ForCausalLM(OVModelForCausalLM):
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # The first time the input length crosses the short/long factor switching point, force the cache to be
+        # recomputed. This slows down that single token position, but is better than the current failure.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
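
For reviewers, a minimal standalone sketch (not part of the patch) of the behaviour the new `select_ext_factor`/`long_rope` helpers implement: the short-context inverse frequencies are used while the sequence fits into `original_max_position_embeddings`, and past that point the long-context frequencies plus the log-based attention scaling factor take over. The config values and inverse-frequency tensors below are made up for illustration only.

# Illustrative only: toy numbers, not taken from any real Phi-3 checkpoint.
import math

import torch


def select_ext_factor(seq_len, max_pos_embeddings, short_factor, long_factor):
    # Same selection rule as the patched helper: short-context frequencies while the
    # sequence fits into original_max_position_embeddings, long-context ones afterwards.
    return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)


original_max = 4096      # stand-in for config.original_max_position_embeddings
extended_max = 131072    # stand-in for config.max_position_embeddings
short_inv_freq = torch.tensor([1.0, 0.1, 0.01])
long_inv_freq = torch.tensor([0.5, 0.05, 0.005])

for seq_len in (1024, 8192):
    inv_freq = select_ext_factor(
        torch.tensor(seq_len), torch.tensor(original_max), short_inv_freq, long_inv_freq
    )
    # Magnitude correction applied by long_rope() only when the context window was extended.
    scale = extended_max / original_max
    scaling_factor = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max))
    print(seq_len, inv_freq.tolist(), round(scaling_factor, 4))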