@@ -4974,6 +4974,123 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield (new_name, data_torch)


+@ModelBase.register("JambaForCausalLM")
+class JambaModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
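+    # Skip the hash-based pre-tokenizer detection and always report the GPT-2 pre-tokenizer.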
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer  # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
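+        # Per-layer KV head counts: attention layers (every attn_period blocks,
+        # starting at attn_offset) get n_kv_head, Mamba layers get 0.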
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
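+    # Per-block buffers of expert tensors, collected until all experts of a block
+    # have been seen and can be merged.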
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba compatibility: some checkpoints name the MoE block ".moe." and
+        # keep a single ".experts.0." wrapper on layers that are not expert layers;
+        # normalize both so the regular tensor name mapping applies.
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
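+        # The conv1d weight carries an extra singleton dimension;
+        # squeeze it away to match the 2D SSM_CONV1D tensor layout.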
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+
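+        # The checkpoint stores A_log; llama.cpp expects A = -exp(A_log), so apply
+        # the transform at conversion time.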
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R