@@ -8836,6 +8836,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]


+@ModelBase.register("Lfm2MoeForCausalLM")
+class LFM2MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LFM2MOE
+
+    def set_gguf_parameters(self):
+        # set num_key_value_heads only for attention layers
+        self.hparams["num_key_value_heads"] = [
+            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            for layer_type in self.hparams["layer_types"]
+        ]
+
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
+
+    # cache for expert weights, for merging
+    _experts_cache: dict[int, dict[str, Tensor]] = {}
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # conv op requires 2d tensor
+        if 'conv.conv' in name:
+            data_torch = data_torch.squeeze(1)
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        # merge expert weights
+        if 'experts' in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            expert_cache = self._experts_cache.setdefault(bid, {})
+            expert_cache[name] = data_torch
+            expert_weights = ["w1", "w2", "w3"]
+
+            # not enough expert weights to merge yet
+            if len(expert_cache) < n_experts * len(expert_weights):
+                return []
+
+            tensors: list[tuple[str, Tensor]] = []
+            for w_name in expert_weights:
+                datas: list[Tensor] = []
+
+                for xid in range(n_experts):
+                    ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
+                    datas.append(expert_cache[ename])
+                    del expert_cache[ename]
+
+                data_torch = torch.stack(datas, dim=0)
+                merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
+                new_name = self.map_tensor_name(merged_name)
+                tensors.append((new_name, data_torch))
+
+            del self._experts_cache[bid]
+            return tensors
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        # every cached expert set must have been merged and emitted by now
+        assert not self._experts_cache
+
+
 @ModelBase.register("Lfm2VlForConditionalGeneration")
 class LFM2VLModel(MmprojModel):
     def __init__(self, *args, **kwargs):