
Commit fc76020

phi3moe support (#1215)
* phi3moe support
* add tests
* use transformers code
1 parent 6a5a01e commit fc76020

5 files changed: +91 -2 lines changed

docs/source/openvino/models.mdx

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ Here is the list of the supported architectures :
 - Persimmon
 - Phi
 - Phi3
+- Phi3.5-MoE
 - Phi3Vision
 - Pix2Struct
 - PoolFormer
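
With the architecture listed as supported, a Phi-3.5-MoE checkpoint should load through the same path as the other causal LMs in optimum-intel. A minimal sketch (not part of this commit), assuming the public microsoft/Phi-3.5-MoE-instruct checkpoint and transformers >= 4.46.0:

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "microsoft/Phi-3.5-MoE-instruct"  # illustrative checkpoint, not referenced by this commit
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # export=True converts the PyTorch checkpoint to OpenVINO IR on the fly,
    # which now resolves to the PhiMoEOpenVINOConfig added below
    model = OVModelForCausalLM.from_pretrained(model_id, export=True)

    inputs = tokenizer("OpenVINO is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))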

optimum/exporters/openvino/model_configs.py

Lines changed: 21 additions & 0 deletions
@@ -109,6 +109,7 @@
     PersimmonModelPatcher,
     Phi3ModelPatcher,
     Phi3VisionImageEmbeddingsPatcher,
+    PhiMoEModelPatcher,
     Qwen2_5_VLVisionEmbMergerPatcher,
     Qwen2VLLanguageModelPatcher,
     Qwen2VLVisionEmbMergerPatcher,
@@ -737,6 +738,26 @@ def patch_model_for_export(
         return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "phimoe",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class PhiMoEOpenVINOConfig(Phi3OpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = "4.46.0"
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return PhiMoEModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "phi",
     *[
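
Registering the config under "phimoe" makes it discoverable through the regular TasksManager lookup. A hedged sketch of that lookup follows; the exact keyword arguments may differ between optimum versions, and importing the OpenVINO config module is assumed to trigger the registration:

    from optimum.exporters.tasks import TasksManager
    import optimum.exporters.openvino.model_configs  # noqa: F401  assumed to register the "openvino" configs

    # should return a constructor for PhiMoEOpenVINOConfig once the registration above ran
    config_constructor = TasksManager.get_exporter_config_constructor(
        exporter="openvino",
        model_type="phimoe",
        task="text-generation-with-past",
        library_name="transformers",
    )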

optimum/exporters/openvino/model_patcher.py

Lines changed: 67 additions & 1 deletion
@@ -1598,7 +1598,10 @@ def __enter__(self):
                 layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
                 layer.self_attn._orig_forward = orig_self_attn_fwd
 
-            if hasattr(layer.self_attn, "rotary_emb") and layer.self_attn.rotary_emb.inv_freq is None:
+            if (
+                hasattr(layer.self_attn, "rotary_emb")
+                and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
+            ):
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
@@ -1615,6 +1618,69 @@ def __exit__(self, exc_type, exc_value, traceback):
             layer.self_attn.forward = layer.self_attn._orig_forward
 
 
+# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
+# removed the `continue` statement, which is unfriendly to tracing
+def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    from transformers.models.phimoe.modeling_phimoe import sparsemixer
+
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    if self.training and self.input_jitter_noise > 0:
+        hidden_states *= torch.empty_like(hidden_states).uniform_(
+            1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
+        )
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights, selected_experts = sparsemixer(
+        router_logits,
+        jitter_eps=self.router_jitter_noise,
+        training=self.training,
+    )
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    # One hot encode the selected experts to create an expert mask
+    # this will be used to easily index which expert is going to be solicited
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+    # Loop over all available experts in the model and perform the computation on each expert
+    for expert_idx in range(self.num_experts):
+        expert_layer = self.experts[expert_idx]
+        idx, top_x = torch.where(expert_mask[expert_idx])
+
+        # if top_x.shape[0] == 0:
+        #     continue
+
+        # Index the correct hidden states and compute the expert hidden state for
+        # the current expert. We need to make sure to multiply the output hidden
+        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+        # However `index_add_` only supports torch tensors for indexing so we'll use
+        # the `top_x` tensor here.
+        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+class PhiMoEModelPatcher(Phi3ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
+            layer.block_sparse_moe.forward = types.MethodType(
+                _phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe.forward = layer.block_sparse_moe._orig_forward
+
+
 def _aquila_self_attn_sdpa_forward(
     self,
     hidden_states: torch.Tensor,
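
The patcher follows the usual optimum-intel pattern: temporarily rebind the MoE block's forward to a tracing-friendly function and restore the original on exit. A stripped-down sketch of that pattern with a hypothetical toy module (none of these names come from the commit):

    import types

    import torch


    def _traceable_forward(self, x: torch.Tensor) -> torch.Tensor:
        # stand-in for _phi_moe_sparse_moe_block_forward: same shape contract, no data-dependent branching
        return self.linear(x)


    class ForwardPatcher:
        """Context manager that swaps a module's forward, mirroring PhiMoEModelPatcher's enter/exit logic."""

        def __init__(self, module: torch.nn.Module):
            self.module = module

        def __enter__(self):
            self.module._orig_forward = self.module.forward
            self.module.forward = types.MethodType(_traceable_forward, self.module)
            return self.module

        def __exit__(self, exc_type, exc_value, traceback):
            self.module.forward = self.module._orig_forward


    class ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            # data-dependent control flow like this is what the real patcher works around
            return self.linear(x) if x.sum() > 0 else x


    block = ToyBlock()
    with ForwardPatcher(block):
        traced = torch.jit.trace(block, torch.randn(1, 4))  # traces the patched, branch-free forward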

tests/openvino/test_modeling.py

Lines changed: 1 addition & 1 deletion
@@ -1032,7 +1032,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         SUPPORTED_ARCHITECTURES += ("granite", "granite-moe")
 
     if is_transformers_version(">=", "4.46.0"):
-        SUPPORTED_ARCHITECTURES += ("glm", "mistral-nemo", "minicpm3")
+        SUPPORTED_ARCHITECTURES += ("glm", "mistral-nemo", "minicpm3", "phi3-moe")
     # openvino 2025.0 required for disabling check_trace
     if is_openvino_version(">=", "2025.0"):
         SUPPORTED_ARCHITECTURES += ("deepseek",)

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@
     "pix2struct": "fxmarty/pix2struct-tiny-random",
     "phi": "echarlaix/tiny-random-PhiForCausalLM",
     "phi3": "Xenova/tiny-random-Phi3ForCausalLM",
+    "phi3-moe": "katuni4ka/phi-3.5-moe-tiny-random",
     "phi3_v": "katuni4ka/tiny-random-phi3-vision",
     "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
     "qwen": "katuni4ka/tiny-random-qwen",
