
Commit 4440904: test longrope phi4
1 parent: 560b980

File tree: 2 files changed (+57, -8 lines)


optimum/exporters/openvino/model_patcher.py

Lines changed: 53 additions & 4 deletions
@@ -1608,15 +1608,51 @@ def _phi3_self_attn_sdpa_forward(
         return attn_output, None, past_key_value
 
 
+@torch.jit.script
+def select_ext_factor(seq_len: torch.Tensor, max_pos_embeddings: torch.Tensor, short_factor: torch.Tensor, long_factor: torch.Tensor):
+    if seq_len > max_pos_embeddings:
+        return long_factor
+    return short_factor
+
+def long_rope(self, x, position_ids, seq_len=None):
+    seq_len = torch.max(position_ids) + 1
+    original_max_position_embeddings = (
+        self.original_max_position_embeddings
+        if hasattr(self, "original_max_position_embeddings") else self.config.original_max_position_embeddings
+    )
+    max_position_embeddings = self.max_position_embeddings if hasattr(self, "max_position_embeddings") else self.config.max_position_embeddings
+    inv_freq = select_ext_factor(
+        seq_len,
+        torch.tensor(original_max_position_embeddings),
+        self.inv_freq,
+        self.long_inv_freq,
+    )
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+    emb = torch.cat((freqs, freqs), dim=-1)
+
+    scale = max_position_embeddings / original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos, sin
+
+
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
         # currently, long RoPE can not be traced for long context support; disable it to avoid potential accuracy issues
-        if self._model.config.max_position_embeddings != getattr(
-            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
-        ):
-            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
 
         if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             self._model.model._orig_forward = self._model.model.forward
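
`select_ext_factor` is decorated with `@torch.jit.script` because its branch on `seq_len` is data-dependent: `torch.jit.trace` records only the path taken by the example inputs, while scripting keeps the conditional in the exported graph. A minimal standalone sketch of the difference (function and variable names here are illustrative, not part of the patch):

import torch

def eager_select(seq_len, max_pos, short_factor, long_factor):
    # Data-dependent branch: tracing records only the path the
    # example inputs take and silently drops the other one.
    if seq_len > max_pos:
        return long_factor
    return short_factor

@torch.jit.script
def scripted_select(seq_len: torch.Tensor, max_pos: torch.Tensor,
                    short_factor: torch.Tensor, long_factor: torch.Tensor):
    # Scripting compiles the `if` itself, so both branches survive export.
    if seq_len > max_pos:
        return long_factor
    return short_factor

short_f, long_f = torch.ones(4), torch.full((4,), 2.0)
example = (torch.tensor(10), torch.tensor(4096), short_f, long_f)
traced = torch.jit.trace(eager_select, example)  # emits a TracerWarning

# Past the original window: the traced graph still returns the short
# factors, while the scripted one correctly switches to the long factors.
print(traced(torch.tensor(8192), torch.tensor(4096), short_f, long_f))
print(scripted_select(torch.tensor(8192), torch.tensor(4096), short_f, long_f))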
@@ -1643,6 +1679,17 @@ def __enter__(self):
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
+
+        if hasattr(self._model.model, "rotary_emb") and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope":
+            long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(self._model.config, torch.device("cpu"), seq_len=self._model.config.original_max_position_embeddings + 1)
+            self._model.model.rotary_emb.long_inv_freq = long_inv_freq
+            self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
+            self._model.model.rotary_emb.forward = types.MethodType(long_rope, self._model.model.rotary_emb)
+        elif self._model.config.max_position_embeddings != getattr(
+            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
+        ):
+            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
+
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
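
The `seq_len=self._model.config.original_max_position_embeddings + 1` argument is what forces `rope_init_fn` down its long-context path, so the long `inv_freq` table can be captured once, ahead of tracing. A hedged sketch of what a longrope-style init does with that argument, simplified from the transformers implementation (field names follow the Phi-3 config; the function body is illustrative):

import torch
from types import SimpleNamespace

def longrope_init_sketch(config, device, seq_len=None):
    # Simplified: the longrope initializer chooses the rescale factors by
    # comparing seq_len against the original training window, which is why
    # the patcher passes original_max_position_embeddings + 1.
    if seq_len is not None and seq_len > config.original_max_position_embeddings:
        ext_factors = torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32)
    else:
        ext_factors = torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32)
    dim = config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (ext_factors * config.rope_theta ** (torch.arange(0, dim, 2).float() / dim))
    return inv_freq, 1.0  # (inv_freq, attention scaling)

cfg = SimpleNamespace(
    hidden_size=64, num_attention_heads=8, rope_theta=10000.0,
    original_max_position_embeddings=4096,
    rope_scaling={"short_factor": [1.0] * 4, "long_factor": [4.0] * 4},
)
long_inv_freq, _ = longrope_init_sketch(cfg, torch.device("cpu"), seq_len=4097)

With both the short and long tables stored on the module, the exported graph only has to pick between two constants at runtime via `select_ext_factor`.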
@@ -1653,6 +1700,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+        if hasattr(self._model.model, "rotary_emb") and hasattr(self._model.model.rotary_emb, "_orig_forward"):
+            self._model.model.rotary_emb.forward = self._model.model.rotary_emb._orig_forward
 
 
 # Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
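
`__exit__` mirrors the patch applied in `__enter__`: the original bound method is stashed on the instance and put back afterwards. The stash-and-restore idiom with `types.MethodType`, in isolation (class and return values here are illustrative):

import types

class RotaryEmbedding:
    def forward(self, x):
        return "original"

def long_rope(self, x):
    # `self` is the instance the function gets bound to below.
    return "patched"

rope = RotaryEmbedding()

# __enter__: stash the original bound method, bind the replacement.
rope._orig_forward = rope.forward
rope.forward = types.MethodType(long_rope, rope)
assert rope.forward(None) == "patched"

# __exit__: restore, leaving the instance exactly as it was found.
rope.forward = rope._orig_forward
assert rope.forward(None) == "original"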

optimum/intel/openvino/modeling_decoder.py

Lines changed: 4 additions & 4 deletions
@@ -341,10 +341,10 @@ def _from_transformers(
             variant=variant,
         )
 
-        if config.model_type == "phi3" and config.max_position_embeddings != getattr(
-            config, "original_max_position_embeddings", config.max_position_embeddings
-        ):
-            config.max_position_embeddings = config.original_max_position_embeddings
+        # if config.model_type == "phi3" and config.max_position_embeddings != getattr(
+        #     config, "original_max_position_embeddings", config.max_position_embeddings
+        # ):
+        #     config.max_position_embeddings = config.original_max_position_embeddings
 
         return cls._from_pretrained(
             model_id=save_dir_path,
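
With this load-time clamp commented out, an exported Phi-3 model keeps its extended `max_position_embeddings` instead of being reset to the original window; the longrope handling in the patcher above takes over at export time. A hedged usage sketch (the checkpoint id is an example; any Phi-3 checkpoint with longrope scaling would behave the same way):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "microsoft/Phi-3-mini-128k-instruct"  # example checkpoint
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Previously clamped to original_max_position_embeddings (4096 for this
# checkpoint); with this change the extended window survives export.
print(model.config.max_position_embeddings)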
