diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index d1dbecf77b..0edade1e9a 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -1493,16 +1493,51 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value
 
 
+def select_ext_factor(
+    seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
+):
+    return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)
+
+
+def long_rope(self, x, position_ids, seq_len=None):
+    seq_len = torch.max(position_ids) + 1
+    original_max_position_embeddings = (
+        self.original_max_position_embeddings
+        if hasattr(self, "original_max_position_embeddings")
+        else self.config.original_max_position_embeddings
+    )
+    max_position_embeddings = (
+        self.max_position_embeddings
+        if hasattr(self, "max_position_embeddings")
+        else self.config.max_position_embeddings
+    )
+    inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos, sin
+
+
 class Phi3ModelPatcher(OVDecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
-        # currently, long RoPE can not be traced for long context support, disable it for avoid potential accuracy issues
-        if self._model.config.max_position_embeddings != getattr(
-            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
-        ):
-            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
-
+        # currently, long RoPE can not be traced for long context support, disable it to avoid potential accuracy issues
         if is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
             self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
@@ -1529,6 +1564,23 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
 
+        if (
+            hasattr(self._model.model, "rotary_emb")
+            and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
+        ):
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config,
+                torch.device("cpu"),
+                seq_len=self._model.config.original_max_position_embeddings + 1,
+            )
+            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
+            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
+            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
+        elif self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_forward"):
@@ -1538,6 +1590,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
+            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
 
 
 # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
@@ -1590,13 +1644,47 @@ def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torc
 
 class PhiMoEModelPatcher(Phi3ModelPatcher):
     def __enter__(self):
-        super().__enter__()
+        # Call OVDecoderModelPatcher.__enter__() directly to skip Phi3ModelPatcher's longrope logic
+        # PhiMoE has a different rotary embedding structure, longrope is not yet supported
+        OVDecoderModelPatcher.__enter__(self)
+
+        if is_transformers_version("<", "4.48.0"):
+            self._model.model._orig_forward = self._model.model.forward
+            self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
+
+        # init inv_freq for torchscript tracing for PhiMoE
+        for layer in self._model.model.layers:
+            if (
+                is_torch_version(">=", "2.1.0")
+                and is_transformers_version("<", "4.48.0")
+                or not getattr(self._model, "_supports_sdpa", False)
+            ):
+                orig_self_attn_fwd = layer.self_attn.forward
+                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
+                layer.self_attn._orig_forward = orig_self_attn_fwd
+
+            if (
+                hasattr(layer.self_attn, "rotary_emb")
+                and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
+            ):
+                rotary_emb = layer.self_attn.rotary_emb
+                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
+                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
+                )
+
+        # Apply MoE-specific patching
         for layer in self._model.model.layers:
             layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
             layer.block_sparse_moe.forward = types.MethodType(
                 _phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
            )
 
+        # For PhiMoE, reset max_position_embeddings if it was extended (skip longrope support for now)
+        if self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for layer in self._model.model.layers:
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 1dd6434641..6a002673d2 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -358,11 +358,6 @@ def _export(
             variant=variant,
         )
 
-        if config.model_type == "phi3" and config.max_position_embeddings != getattr(
-            config, "original_max_position_embeddings", config.max_position_embeddings
-        ):
-            config.max_position_embeddings = config.original_max_position_embeddings
-
         return cls._from_pretrained(
             model_id=save_dir_path,
             config=config,
@@ -870,6 +865,8 @@ def _from_pretrained(
             init_cls = OVBloomForCausalLM
         elif model_type == "gpt_bigcode":
             init_cls = OVGPTBigCodeForCausalLM
+        elif model_type == "phi3":
+            init_cls = OVPhi3ForCausalLM
         elif model_type in SSM_MODELS:
             init_cls = OVModelWithMambaForCausalLM
         else:
@@ -950,6 +947,47 @@ def _from_pretrained(
         return causal_model
 
 
+class OVPhi3ForCausalLM(OVModelForCausalLM):
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # The first time the input length crosses the long/short factor switching point, force the cache to be
+        # recomputed. This slows down that single token position, but it is better than the current failure.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
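
Not part of the patch: a minimal usage sketch of the behavior the changes above enable, kept outside the diff so the hunks stay intact. It assumes the public optimum-intel API (OVModelForCausalLM with export=True) and a Phi-3 LongRoPE checkpoint; the model id, prompt, and generation settings below are illustrative only.

# Illustrative sketch, not part of the patch. Assumes optimum-intel with the OpenVINO
# extra installed; "microsoft/Phi-3-mini-128k-instruct" is just one example LongRoPE checkpoint.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "microsoft/Phi-3-mini-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to OpenVINO IR; with this patch, Phi3ModelPatcher traces
# long_rope instead of clamping config.max_position_embeddings to original_max_position_embeddings.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# While the sequence stays at or below config.original_max_position_embeddings (4096 for this
# checkpoint), select_ext_factor picks the short inv_freq inside the traced rotary embedding.
inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))

# Once the total length crosses that threshold, OVPhi3ForCausalLM.prepare_inputs_for_generation
# drops the KV cache a single time at the switching point, and select_ext_factor switches to the
# long inv_freq inside the exported model.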