@@ -3581,3 +3581,72 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.blocks:
             block.forward = block._orig_forward
             block.attn.forward = block.attn._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321
+def _granite_moe_topk_gating_forward(self, hidden_states):
+    # compute the top_k routing decision
+    logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
+    top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
+    top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
+
+    # compute number of input given to each expert
+    zeros = torch.zeros(
+        [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
+    )  # [num_tokens, num_experts]
+    gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
+    expert_size = gates.long().sum(0)  # [num_experts,]
+    # difference from the original: expert_size = expert_size.tolist() was removed because it is traced incorrectly
+
+    # sort and group input tokens according to expert assignment
+    top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
+    _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
+    batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
+
+    # gather the gate values for grouped input tokens
+    top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
+    batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
+
+    return index_sorted_experts, batch_index, batch_gates, expert_size, logits
+
+
+# copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L281
+def _granite_moe_parallel_experts_forward(self, inputs, expert_size):
+    output_list = []
+    # differences from the original:
+    # 1) after the gating patch expert_size is a tensor rather than a list of ints, so the original inputs.split(expert_size) cannot be used
+    # 2) expert input slices are taken one by one via index_start:next_index instead of being precomputed before the loop
+    index_start = torch.tensor(0, dtype=torch.int64)
+    for i in range(self.num_experts):
+        next_index = index_start + expert_size[i]
+        output_list.append(F.linear(inputs[index_start:next_index], self.weight[i]))
+        index_start = next_index
+    results = torch.cat(output_list, dim=0)
+    return results
+
+
+class GraniteMoEModelPatcher(LlamaModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router._orig_forward = block_sparse_moe.router.forward
+            block_sparse_moe.router.forward = types.MethodType(
+                _granite_moe_topk_gating_forward, block_sparse_moe.router
+            )
+            block_sparse_moe.input_linear._orig_forward = block_sparse_moe.input_linear.forward
+            block_sparse_moe.input_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.input_linear
+            )
+            block_sparse_moe.output_linear._orig_forward = block_sparse_moe.output_linear.forward
+            block_sparse_moe.output_linear.forward = types.MethodType(
+                _granite_moe_parallel_experts_forward, block_sparse_moe.output_linear
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for layer in self._model.model.layers:
+            block_sparse_moe = layer.block_sparse_moe
+            block_sparse_moe.router.forward = block_sparse_moe.router._orig_forward
+            block_sparse_moe.input_linear.forward = block_sparse_moe.input_linear._orig_forward
+            block_sparse_moe.output_linear.forward = block_sparse_moe.output_linear._orig_forward
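The patcher follows the same context-manager pattern as the other patchers in this file: the MoE router and the expert linear layers are monkey-patched on __enter__ and restored on __exit__, so the tracing-friendly forwards are only active while the model is being exported. A minimal usage sketch, assuming the ModelPatcher-style (config, model) constructor; export_config, granite_moe_model and example_inputs below are illustrative placeholders, not part of this diff:

    # hypothetical sketch: export_config, granite_moe_model and example_inputs are
    # placeholder names used only to illustrate the patching flow
    patcher = GraniteMoEModelPatcher(export_config, granite_moe_model)
    with patcher:
        # inside the context each layer.block_sparse_moe has its router, input_linear
        # and output_linear forwards replaced by the tracing-friendly versions above
        traced = torch.jit.trace(granite_moe_model, example_kwarg_inputs=example_inputs, strict=False)
    # on exit the original forwards are restored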