Commit 72cd3c8

update prepare_inputs_for_generation

1 parent 4440904 · commit 72cd3c8

File tree

2 files changed: +68 -16 lines changed

optimum/exporters/openvino/model_patcher.py
optimum/intel/openvino/modeling_decoder.py


optimum/exporters/openvino/model_patcher.py

Lines changed: 25 additions & 16 deletions
@@ -1608,25 +1608,28 @@ def _phi3_self_attn_sdpa_forward(
     return attn_output, None, past_key_value
 
 
-@torch.jit.script
-def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor):
-    if seq_len > max_pos_embeddings:
-        return long_factor
-    return short_factor
+# @torch.jit.script
+def select_ext_factor(
+    seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor
+):
+    return torch.where(
+        seq_len < max_pos_embeddings, short_factor, long_factor
+    )  # short_factor * (seq_len <= max_pos_embeddings) + long_factor * (seq_len > max_pos_embeddings)
+
 
 def long_rope(self, x, position_ids, seq_len=None):
     seq_len = torch.max(position_ids) + 1
     original_max_position_embeddings = (
         self.original_max_position_embeddings
-        if hasattr(self, "original_max_positional_embeddings") else self.config.original_max_position_embeddings
+        if hasattr(self, "original_max_positional_embeddings")
+        else self.config.original_max_position_embeddings
     )
-    max_position_embeddings = self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
-    inv_freq = select_ext_factor(
-        seq_len,
-        torch.tensor(original_max_position_embeddings),
-        self.inv_freq,
-        self.long_inv_freq
+    max_position_embeddings = (
+        self.max_position_embeddings
+        if hasattr(self, "max_position_embeddings")
+        else self.config.max_position_embeddings
     )
+    inv_freq = select_ext_factor(seq_len, original_max_position_embeddings, self.inv_freq, self.long_inv_freq)
 
     inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
     position_ids_expanded = position_ids[:, None, :].float()
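
Why the swap from a tensor `if` to `torch.where` matters for export: tracing a Python branch on a tensor freezes whichever path the example input happened to take, while `torch.where` keeps both factor tensors in the traced graph and selects at run time. A minimal sketch of that difference, assuming only standard `torch.jit.trace` behavior (the toy tensors and the 4096 threshold are illustrative, not taken from the commit):

import torch


def select_with_if(seq_len, max_pos, short_factor, long_factor):
    # Data-dependent Python control flow: tracing bakes in a single branch.
    if seq_len > max_pos:
        return long_factor
    return short_factor


def select_with_where(seq_len, max_pos, short_factor, long_factor):
    # Both operands stay in the graph; the selection happens at run time.
    return torch.where(seq_len < max_pos, short_factor, long_factor)


short, long = torch.ones(3), torch.full((3,), 2.0)
max_pos = torch.tensor(4096)
example = (torch.tensor(10), max_pos, short, long)

traced_if = torch.jit.trace(select_with_if, example)       # emits a TracerWarning about the branch
traced_where = torch.jit.trace(select_with_where, example)

print(traced_if(torch.tensor(10_000), max_pos, short, long))     # tensor([1., 1., 1.]) - short branch was frozen in
print(traced_where(torch.tensor(10_000), max_pos, short, long))  # tensor([2., 2., 2.]) - long factors, as intended
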
@@ -1679,9 +1682,16 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
-
-        if hasattr(self._model.model, "rotary_emb") and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope":
-            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(self._model.config, torch.device("cpu"), seq_len=self._model.config.original_max_position_embeddings + 1)
+
+        if (
+            hasattr(self._model.model, "rotary_emb")
+            and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
+        ):
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
+                self._model.config,
+                torch.device("cpu"),
+                seq_len=self._model.config.original_max_position_embeddings + 1,
+            )
             self._model.model.rotary_emb.long_inv_freq = long_inv_freq
             self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
             self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
@@ -1690,7 +1700,6 @@ def __enter__(self):
         ):
             self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
 
-
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_forward"):
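
The patcher above precomputes a second frequency table, `long_inv_freq`, by calling `rope_init_fn` at a sequence length past `original_max_position_embeddings`; the patched `long_rope` then picks one of the two tables per call through `select_ext_factor`. A standalone sketch of that selection, reusing `select_ext_factor` from the patch; `DummyRotaryEmbedding`, its dimensions, and the `/ 4.0` scaling are made up for illustration and are not the real longrope factors:

import torch


def select_ext_factor(seq_len, max_pos_embeddings, short_factor, long_factor):
    return torch.where(seq_len < max_pos_embeddings, short_factor, long_factor)


class DummyRotaryEmbedding:
    """Stand-in for the patched rotary embedding: holds a short and a long inverse-frequency buffer."""

    def __init__(self, dim=8, base=10000.0, original_max_position_embeddings=4096):
        self.original_max_position_embeddings = original_max_position_embeddings
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.long_inv_freq = self.inv_freq / 4.0  # placeholder for the longrope-scaled frequencies


rope = DummyRotaryEmbedding()

for position_ids in (torch.arange(16)[None, :], torch.arange(5000)[None, :]):
    seq_len = torch.max(position_ids) + 1
    inv_freq = select_ext_factor(
        seq_len, rope.original_max_position_embeddings, rope.inv_freq, rope.long_inv_freq
    )
    regime = "long" if torch.equal(inv_freq, rope.long_inv_freq) else "short"
    print(f"seq_len={int(seq_len)} -> {regime} inv_freq")  # 16 -> short, 5000 -> long
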

optimum/intel/openvino/modeling_decoder.py

Lines changed: 43 additions & 0 deletions
@@ -846,6 +846,8 @@ def _from_pretrained(
             init_cls = OVBloomForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
+        elif model_type == "phi3":
+            init_cls = OVPhi3ForCausalLM
         else:
             init_cls = cls
 
@@ -915,6 +917,47 @@ def _from_pretrained(
     return causal_model
 
 
+class OVPhi3ForCausalLM(OVModelForCausalLM):
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
+        # process
+
+        # The first time the input length reaches the long/short factor switching point, force the cache to be
+        # re-computed. This slows down that single token position, but it is better than the current failure.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+        return model_inputs
+
+
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
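
A usage sketch of the new dispatch path; the checkpoint name and prompt are illustrative, not part of the commit. After this change, loading a phi3 model through OVModelForCausalLM should hand back an OVPhi3ForCausalLM instance, whose prepare_inputs_for_generation drops the KV cache once at the step where the running length first exceeds original_max_position_embeddings:

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "microsoft/Phi-3-mini-128k-instruct"  # any phi3 checkpoint with longrope scaling
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

print(type(model).__name__)  # expected: OVPhi3ForCausalLM, via the new phi3 branch in _from_pretrained

inputs = tokenizer("OpenVINO is", return_tensors="pt")
# If generation crosses config.original_max_position_embeddings, the override above recomputes
# the cache once at that position instead of failing.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
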
