102 changes: 97 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
@@ -1493,15 +1493,54 @@ def _phi3_self_attn_sdpa_forward(
return attn_output, None, past_key_value


# @torch.jit.script
Collaborator:

I think it makes sense to add a test with a long prompt. Is this issue reproduced on the tiny model?
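A minimal sketch of what such a long-prompt test could look like, assuming a tiny Phi-3 checkpoint configured with longrope scaling; the model ID, prompt construction, and tolerance are illustrative placeholders rather than the repository's actual test fixtures:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from optimum.intel import OVModelForCausalLM

# Hypothetical placeholder checkpoint; it must use rope_scaling of type "longrope"
# so that original_max_position_embeddings < max_position_embeddings.
model_id = "tiny-random-phi3-with-longrope"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)

# Build a prompt longer than original_max_position_embeddings so the long factors are exercised.
long_prompt = "hello " * (config.original_max_position_embeddings + 8)
inputs = tokenizer(long_prompt, return_tensors="pt")

with torch.no_grad():
    ref_logits = ref_model(**inputs).logits
ov_logits = ov_model(**inputs).logits

assert torch.allclose(ref_logits, torch.as_tensor(ov_logits), atol=1e-3)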

Collaborator:

Please remove unneeded comments and commented-out code.

def select_ext_factor(
seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
):
return torch.where(
seq_len <= max_pos_embeddings, short_factor, long_factor
) # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings)


def long_rope(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
original_max_position_embeddings = (
self.original_max_position_embeddings
if hasattr(self, "original_max_position_embeddings")
else self.config.original_max_position_embeddings
)
max_position_embeddings = (
self.max_position_embeddings
if hasattr(self, "max_position_embeddings")
else self.config.max_position_embeddings
)
inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)

scale = max_position_embeddings / original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos, sin
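For a quick sanity check of the attention scaling applied above, here is the formula evaluated with Phi-3-mini-128k style values (131072 extended vs. 4096 original positions); the numbers are illustrative and not read from an actual config:

import math

# Illustrative longrope configuration values (not loaded from a real config).
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
print(round(scaling_factor, 4))  # ~1.1902: cos/sin are amplified slightly for the extended context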


class Phi3ModelPatcher(OVDecoderModelPatcher):
def __enter__(self):
super().__enter__()

# currently, long RoPE can not be traced for long context support, disable it for avoid potential accuracy issues
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

if is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
@@ -1529,6 +1568,23 @@ def __enter__(self):
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

if (
hasattr(self._model.model, "rotary_emb")
and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
):
long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
self._model.config,
torch.device("cpu"),
seq_len=self._model.config.original_max_position_embeddings + 1,
)
self._model.model.rotary_emb.long_inv_freq = long_inv_freq
self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
elif self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.model, "_orig_forward"):
@@ -1538,6 +1594,8 @@ def __exit__(self, exc_type, exc_value, traceback):
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward
if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
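For reference, the patcher above is only active while the PyTorch model is traced for conversion; a hedged end-to-end sketch of triggering it through the public API (the 128k model ID is illustrative, and any longrope Phi-3 checkpoint should follow the same path):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

# Exporting a longrope Phi-3 variant applies Phi3ModelPatcher during conversion, so the
# traced graph should contain both the short and the long inverse frequency constants.
model_id = "microsoft/Phi-3-mini-128k-instruct"  # illustrative checkpoint
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("OpenVINO handles long contexts", return_tensors="pt")
generated = ov_model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))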


# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
@@ -1590,13 +1648,47 @@ def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class PhiMoEModelPatcher(Phi3ModelPatcher):
def __enter__(self):
super().__enter__()
# Call OVDecoderModelPatcher.__enter__() directly to skip Phi3ModelPatcher's longrope logic
# PhiMoE has a different rotary embedding structure; longrope is not yet supported
OVDecoderModelPatcher.__enter__(self)

if is_transformers_version("<", "4.48.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)

# init inv_freq for torchscript tracing for PhiMoE
for layer in self._model.model.layers:
if (
is_torch_version(">=", "2.1.0")
and is_transformers_version("<", "4.48.0")
or not getattr(self._model, "_supports_sdpa", False)
):
orig_self_attn_fwd = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
layer.self_attn._orig_forward = orig_self_attn_fwd

if (
hasattr(layer.self_attn, "rotary_emb")
and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
):
rotary_emb = layer.self_attn.rotary_emb
layer.self_attn.rotary_emb.inv_freq = 1.0 / (
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

# Apply MoE-specific patching
for layer in self._model.model.layers:
layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
layer.block_sparse_moe.forward = types.MethodType(
_phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
)

# For PhiMoE, reset max_position_embeddings if it was extended (skip longrope support for now)
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for layer in self._model.model.layers:
51 changes: 47 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
@@ -358,10 +358,10 @@ def _export(
variant=variant,
)

if config.model_type == "phi3" and config.max_position_embeddings != getattr(
config, "original_max_position_embeddings", config.max_position_embeddings
):
config.max_position_embeddings = config.original_max_position_embeddings
# if config.model_type == "phi3" and config.max_position_embeddings != getattr(
# config, "original_max_position_embeddings", config.max_position_embeddings
# ):
# config.max_position_embeddings = config.original_max_position_embeddings
Collaborator (Author) comment on lines +361 to +364:

This should be deleted but I left it here for now for context.


return cls._from_pretrained(
model_id=save_dir_path,
@@ -870,6 +870,8 @@ def _from_pretrained(
init_cls = OVBloomForCausalLM
elif model_type == "gpt_bigcode":
init_cls = OVGPTBigCodeForCausalLM
elif model_type == "phi3":
init_cls = OVPhi3ForCausalLM
elif model_type in SSM_MODELS:
init_cls = OVModelWithMambaForCausalLM
else:
@@ -950,6 +952,47 @@ def _from_pretrained(
return causal_model


class OVPhi3ForCausalLM(OVModelForCausalLM):
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
# process

# The first time the input length reaches the long/short factor switching point, force the cache to be recomputed.
# This slows down that single token position, but it is better than the current failure.
if (
past_key_values
and self.config.rope_scaling
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None

model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
return model_inputs
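A small, hedged illustration of the switching condition above; the helper and numbers are illustrative only (the real check additionally requires config.rope_scaling to be set and operates on tensors):

def needs_cache_reset(input_len: int, past_len: int, original_max: int) -> bool:
    # Mirrors the intent of the condition above: discard past_key_values the first
    # time the sequence crosses the long/short factor switching point.
    return input_len >= original_max + 1 and past_len <= original_max

# Illustrative values with original_max_position_embeddings = 4096:
assert needs_cache_reset(input_len=4097, past_len=4096, original_max=4096)      # first crossing -> recompute
assert not needs_cache_reset(input_len=4098, past_len=4097, original_max=4096)  # already on long factors
assert not needs_cache_reset(input_len=128, past_len=127, original_max=4096)    # short sequence, keep cache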


class OVBloomForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):