66 changes: 62 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
@@ -1608,15 +1608,54 @@ def _phi3_self_attn_sdpa_forward(
    return attn_output, None, past_key_value


# @torch.jit.script
def select_ext_factor(
    seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
):
    return torch.where(
        seq_len < max_pos_embeddings, short_factor, long_factor
    )  # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings)
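Expressing the switch with torch.where keeps it inside the traced graph (a plain Python if on the sequence length would be frozen at trace time, which is what the commented-out arithmetic form also avoids). A minimal sanity check of the selection behaviour, assuming select_ext_factor above is in scope and using made-up factor values:

import torch

short_factor = torch.ones(8)          # hypothetical short-context rescale factors
long_factor = torch.full((8,), 2.0)   # hypothetical long-context rescale factors
original_max = torch.tensor(4096)

# Sequence still inside the original window -> short factors are selected.
assert torch.equal(select_ext_factor(torch.tensor(100), original_max, short_factor, long_factor), short_factor)
# Sequence beyond the original window -> long factors are selected.
assert torch.equal(select_ext_factor(torch.tensor(8192), original_max, short_factor, long_factor), long_factor)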


def long_rope(self, x, position_ids, seq_len=None):
    seq_len = torch.max(position_ids) + 1
    original_max_position_embeddings = (
        self.original_max_position_embeddings
        if hasattr(self, "original_max_position_embeddings")
        else self.config.original_max_position_embeddings
    )
    max_position_embeddings = (
        self.max_position_embeddings
        if hasattr(self, "max_position_embeddings")
        else self.config.max_position_embeddings
    )
    inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()

    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    device_type = x.device.type
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)

    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        scaling_factor = 1.0
    else:
        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    cos = emb.cos() * scaling_factor
    sin = emb.sin() * scaling_factor
    return cos, sin
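For intuition on the magnitude correction: the cos/sin scaling only departs from 1.0 once the configured context exceeds the original one. A standalone sketch of the same formula with assumed sizes (hypothetical, roughly a 4k-original / 128k-extended configuration):

import math

original_max_position_embeddings = 4096   # assumed original context window
max_position_embeddings = 131072          # assumed extended context window

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    scaling_factor = 1.0
else:
    # Same attention-scaling formula as in long_rope above.
    scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

print(scaling_factor)  # ~1.19, applied to both cos and sin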


class Phi3ModelPatcher(DecoderModelPatcher):
    def __enter__(self):
        super().__enter__()

        # currently, long RoPE cannot be traced for long context support; disable it to avoid potential accuracy issues
        if self._model.config.max_position_embeddings != getattr(
            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
        ):
            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
            self._model.model._orig_forward = self._model.model.forward
@@ -1644,6 +1683,23 @@ def __enter__(self):
rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
)

        if (
            hasattr(self._model.model, "rotary_emb")
            and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
        ):
            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
                self._model.config,
                torch.device("cpu"),
                seq_len=self._model.config.original_max_position_embeddings + 1,
            )
            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
        elif self._model.config.max_position_embeddings != getattr(
            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
        ):
            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if hasattr(self._model.model, "_orig_forward"):
@@ -1653,6 +1709,8 @@ def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._model.model.layers:
            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
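Both the rotary_emb patch in __enter__ and its restore here follow the same monkey-patching convention used for the other patched modules in this file: stash the bound method on an _orig_forward attribute, rebind with types.MethodType on enter, and put the original back on exit. A toy, self-contained illustration with hypothetical names:

import types


class _Rope:
    def forward(self):
        return "original"


def _patched_forward(self):
    return "patched"


rope = _Rope()
rope._orig_forward = rope.forward                        # stash the bound method
rope.forward = types.MethodType(_patched_forward, rope)  # rebind on enter
assert rope.forward() == "patched"
rope.forward = rope._orig_forward                        # restore on exit
assert rope.forward() == "original"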


# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
51 changes: 47 additions & 4 deletions optimum/intel/openvino/modeling_decoder.py
@@ -341,10 +341,10 @@ def _from_transformers(
variant=variant,
)

        if config.model_type == "phi3" and config.max_position_embeddings != getattr(
            config, "original_max_position_embeddings", config.max_position_embeddings
        ):
            config.max_position_embeddings = config.original_max_position_embeddings
        # if config.model_type == "phi3" and config.max_position_embeddings != getattr(
        #     config, "original_max_position_embeddings", config.max_position_embeddings
        # ):
        #     config.max_position_embeddings = config.original_max_position_embeddings

        return cls._from_pretrained(
            model_id=save_dir_path,
@@ -846,6 +846,8 @@ def _from_pretrained(
            init_cls = OVBloomForCausalLM
        elif model_type == "gpt-bigcode":
            init_cls = OVGPTBigCodeForCausalLM
        elif model_type == "phi3":
            init_cls = OVPhi3ForCausalLM
        else:
            init_cls = cls

@@ -915,6 +917,47 @@ def _from_pretrained(
        return causal_model


class OVPhi3ForCausalLM(OVModelForCausalLM):
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # The first time the input length reaches the long/short factor switching point, force the cache to be
        # re-computed. This makes that single token position slower, but it is better than the current failure.
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
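The length check means the KV cache built with the short RoPE factors is dropped exactly once, on the first step whose sequence length crosses the original context window; after that, the long-factor cache is reused as usual. A tiny standalone sketch of just that length logic (hypothetical helper, assumed 4096-token original window; the real check above additionally requires rope_scaling to be configured and a non-empty cache):

def should_drop_cache(input_len: int, past_len: int, original_max: int = 4096) -> bool:
    # Mirrors only the length condition from prepare_inputs_for_generation above.
    return input_len >= original_max + 1 and past_len <= original_max


assert not should_drop_cache(4096, 4095)  # still inside the short-factor window
assert should_drop_cache(4097, 4096)      # first token past the window -> recompute with long factors
assert not should_drop_cache(4200, 4199)  # already on long factors, keep the cache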


class OVBloomForCausalLM(OVModelForCausalLM):
    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):