@@ -1598,7 +1598,10 @@ def __enter__(self):
                 layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
                 layer.self_attn._orig_forward = orig_self_attn_fwd

-            if hasattr(layer.self_attn, "rotary_emb") and layer.self_attn.rotary_emb.inv_freq is None:
+            if (
+                hasattr(layer.self_attn, "rotary_emb")
+                and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
+            ):
                 rotary_emb = layer.self_attn.rotary_emb
                 layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
@@ -1615,6 +1618,69 @@ def __exit__(self, exc_type, exc_value, traceback):
                 layer.self_attn.forward = layer.self_attn._orig_forward


+# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756
+# removed usage of the `continue` statement, which is not tracing-friendly
+def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    from transformers.models.phimoe.modeling_phimoe import sparsemixer
+
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    if self.training and self.input_jitter_noise > 0:
+        hidden_states *= torch.empty_like(hidden_states).uniform_(
+            1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
+        )
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights, selected_experts = sparsemixer(
+        router_logits,
+        jitter_eps=self.router_jitter_noise,
+        training=self.training,
+    )
+
+    final_hidden_states = torch.zeros(
+        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+    )
+
+    # One hot encode the selected experts to create an expert mask
+    # this will be used to easily index which expert is going to be solicited
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+    # Loop over all available experts in the model and perform the computation on each expert
+    for expert_idx in range(self.num_experts):
+        expert_layer = self.experts[expert_idx]
+        idx, top_x = torch.where(expert_mask[expert_idx])
+
+        # if top_x.shape[0] == 0:
+        #     continue
+
+        # Index the correct hidden states and compute the expert hidden state for
+        # the current expert. We need to make sure to multiply the output hidden
+        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+        # However `index_add_` only supports torch tensors for indexing so we'll use
+        # the `top_x` tensor here.
+        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+class PhiMoEModelPatcher(Phi3ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
+            layer.block_sparse_moe.forward = types.MethodType(
+                _phi_moe_sparse_moe_block_forward, layer.block_sparse_moe
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            layer.block_sparse_moe.forward = layer.block_sparse_moe._orig_forward
+
+
 def _aquila_self_attn_sdpa_forward(
     self,
     hidden_states: torch.Tensor,
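A side note on the commented-out `if top_x.shape[0] == 0: continue` guard above: it can be dropped for tracing because an empty `top_x` makes both the gather and the `index_add_` calls no-ops, so the loop body is safe to run unconditionally for every expert. Below is a minimal, self-contained sketch of that behaviour; the tensor shapes are illustrative assumptions only and are not taken from the model or the patch.

import torch

hidden_dim = 4
hidden_states = torch.randn(6, hidden_dim)  # 6 flattened tokens
final_hidden_states = torch.zeros_like(hidden_states)

# Simulate an expert that received no tokens: the routing index is empty
top_x = torch.tensor([], dtype=torch.long)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)  # shape (0, hidden_dim)
final_hidden_states.index_add_(0, top_x, current_state)  # no-op, nothing is added
assert torch.all(final_hidden_states == 0)  # output unchanged, same effect as skipping the expert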