@@ -654,6 +654,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
             # ref: https://huggingface.co/moonshotai/Kimi-K2-Base
             res = "kimi-k2"
+        if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
+            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
+            res = "bailingmoe2"

         if res is None:
             logger.warning("\n")
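
For reference, the hash compared above is the SHA-256 of the token IDs the model's tokenizer produces for a fixed test string; the sketch below mirrors the hashing logic in `get_vocab_base_pre()`. The actual test string (`chktxt`) is the long constant defined in that function and is not reproduced here:

    from hashlib import sha256
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-mini-base-2.0")
    chktxt = ...  # the fixed test string from get_vocab_base_pre(), elided here
    chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
    # expected: 9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206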
@@ -4461,6 +4464,103 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         name = name.removeprefix("transformer.")
         return [(self.map_tensor_name(name), data_torch)]

+
+@Model.register("BailingMoeV2ForCausalLM")
+class BailingMoeV2Model(Model):
+    model_arch = gguf.MODEL_ARCH.BAILINGMOE2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
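+        # NextN/MTP (multi-token prediction) tensors live in extra blocks after the
+        # regular hidden layers, so extend the block count to cover them in the tensor map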
+        if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
+            self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
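+        # partial RoPE: only a fraction of each head dimension is rotated; fall back
+        # to 0.5 when the config does not specify partial_rotary_factor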
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        rope_scaling = self.hparams.get("rope_scaling") or {}
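+        # older HF configs spell the key "type" rather than "rope_type"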
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+        else:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
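+        # the first `first_k_dense_replace` layers use a dense FFN instead of MoE experts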
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["moe_shared_expert_intermediate_size"])
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
+        self.gguf_writer.add_expert_group_count(hparams["n_group"])
+        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
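+        # gating function used to normalize router scores before top-k expert selection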
+        if hparams["score_function"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["score_function"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
+
+        if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(nextn_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "mlp.experts" in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            tensors: list[tuple[str, Tensor]] = []
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
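+            # each expert contributes gate_proj, up_proj and down_proj, so the layer is
+            # complete once n_experts * 3 tensors have been collected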
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+
+            return tensors
+
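+        # rename the routing bias so it carries a standard ".bias" suffix that
+        # map_tensor_name can resolve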
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
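+        # sanity check: modify_tensors should have merged and drained every expert dict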
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 ###### CONVERSION LOGIC ######

