102 changes: 97 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
@@ -1493,15 +1493,54 @@ def _phi3_self_attn_sdpa_forward(
return attn_output, None, past_key_value


Reviewer comment (Collaborator): I think it makes sense to add a test with a long prompt. Is this issue reproduced on the tiny model?

Reviewer comment (Collaborator): Please remove unneeded comments and commented-out code.
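A long-prompt regression test responding to the first comment above could look roughly like the sketch below; the tiny checkpoint id, prompt length, and tolerances are placeholders, and the checkpoint is assumed to have longrope rope_scaling configured.

import torch
from transformers import AutoModelForCausalLM

from optimum.intel import OVModelForCausalLM


def test_phi3_logits_match_beyond_original_context():
    model_id = "katuni4ka/tiny-random-phi3"  # placeholder tiny Phi-3 checkpoint
    pt_model = AutoModelForCausalLM.from_pretrained(model_id)
    ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

    # A prompt longer than original_max_position_embeddings selects the long rope factors.
    switch_point = pt_model.config.original_max_position_embeddings
    input_ids = torch.randint(0, pt_model.config.vocab_size, (1, switch_point + 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        pt_logits = pt_model(input_ids=input_ids, attention_mask=attention_mask).logits
    ov_logits = torch.as_tensor(ov_model(input_ids=input_ids, attention_mask=attention_mask).logits)

    torch.testing.assert_close(ov_logits, pt_logits, atol=1e-3, rtol=1e-3)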

def select_ext_factor(
seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
):
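    """Pick short_factor when the sequence fits within max_pos_embeddings and long_factor otherwise, using torch.where so the choice stays traceable."""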
    return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)


def long_rope(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
original_max_position_embeddings = (
self.original_max_position_embeddings
if hasattr(self, "original_max_positional_embeddings")
else self.config.original_max_position_embeddings
)
max_position_embeddings = (
self.max_position_embeddings
if hasattr(self, "max_position_embeddings")
else self.config.max_position_embeddings
)
inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)

scale = max_position_embeddings / original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
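    # Illustration (not from this PR): a 128k Phi-3 variant with max_position_embeddings = 131072 and
    # original_max_position_embeddings = 4096 gives scale = 32 and
    # scaling_factor = sqrt(1 + ln(32) / ln(4096)) ≈ 1.19, applied to cos/sin below.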
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos, sin


class Phi3ModelPatcher(OVDecoderModelPatcher):
def __enter__(self):
super().__enter__()

        # currently, long RoPE cannot be traced for long-context support; disable it to avoid potential accuracy issues
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

if is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
@@ -1529,6 +1568,23 @@ def __enter__(self):
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

if (
hasattr(self._model.model, "rotary_emb")
and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
):
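            # Precompute the inverse frequencies for the long factors by calling the rope init
            # function with a sequence length beyond original_max_position_embeddings, then swap in
            # the traceable long_rope forward that selects between the two tables at runtime.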
long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
self._model.config,
torch.device("cpu"),
seq_len=self._model.config.original_max_position_embeddings + 1,
)
self._model.model.rotary_emb.long_inv_freq = long_inv_freq
self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
elif self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_forward"):
@@ -1538,6 +1594,8 @@ def __exit__(self, exc_type, exc_value, traceback):
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward


# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
@@ -1590,13 +1648,47 @@ def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class PhiMoEModelPatcher(Phi3ModelPatcher):
def __enter__(self):
super().__enter__()
# Call OVDecoderModelPatcher.__enter__() directly to skip Phi3ModelPatcher's longrope logic
# PhiMoE has a different rotary embedding structure, longrope is not yet supported
OVDecoderModelPatcher.__enter__(self)

if is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)

# init inv_freq for torchscript tracing for PhiMoE
for layer in self._model.model.layers:
            if (
                (is_torch_version(">=", "2.1.0") and is_transformers_version("<", "4.48.0"))
                or not getattr(self._model, "_supports_sdpa", False)
            ):
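                # Patch self-attention with the SDPA-based forward for older transformers releases
                # or models that do not natively support SDPA.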
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

if (
hasattr(layer.self_attn, "rotary_emb")
and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
):
rotary_emb = layer.self_attn.rotary_emb
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

# Apply MoE-specific patching
for layer in self._model.model.layers:
layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
layer.block_sparse_moe.forward = types.MethodType(
_phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
)

# For PhiMoE, reset max_position_embeddings if it was extended (skip longrope support for now)
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
Expand Down
51 changes: 47 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
@@ -358,10 +358,10 @@ def _export(
variant=variant,
)

if config.model_type == "phi3" and config.max_position_embeddings != getattr(
config, "original_max_position_embeddings", config.max_position_embeddings
):
config.max_position_embeddings = config.original_max_position_embeddings
# if config.model_type == "phi3" and config.max_position_embeddings != getattr(
# config, "original_max_position_embeddings", config.max_position_embeddings
# ):
# config.max_position_embeddings = config.original_max_position_embeddings
Author comment (Collaborator): This should be deleted but I left it here for now for context.


return cls._from_pretrained(
model_id=save_dir_path,
@@ -870,6 +870,8 @@ def _from_pretrained(
init_cls = OVBloomForCausalLM
elif model_type == "gpt_bigcode":
init_cls = OVGPTBigCodeForCausalLM
elif model_type == "phi3":
init_cls = OVPhi3ForCausalLM
elif model_type in SSM_MODELS:
init_cls = OVModelWithMambaForCausalLM
else:
@@ -950,6 +952,47 @@ def _from_pretrained(
return causal_model


class OVPhi3ForCausalLM(OVModelForCausalLM):
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
# process

        # The first time the input length crosses the long/short factor switching point, force the cache to be
        # recomputed. This slows down that single token position, but it is better than the current failure.
if (
past_key_values
and self.config.rope_scaling
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None

model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
return model_inputs


class OVBloomForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
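Separately from the diff, a minimal end-to-end sketch of the long-context path this PR enables; the checkpoint id, prompt, and generation settings below are placeholders rather than values taken from the PR.

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "microsoft/Phi-3-mini-128k-instruct"  # any Phi-3 checkpoint with longrope rope_scaling
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# A prompt longer than original_max_position_embeddings (4096 for this checkpoint) uses the long
# rope factors from the first token; a shorter prompt that crosses that boundary during generation
# triggers the one-time KV-cache reset in OVPhi3ForCausalLM.prepare_inputs_for_generation.
prompt = "word " * 5000
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))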