@@ -218,6 +218,8 @@ def from_model_architecture(model_architecture):
             return BertModel
         if model_architecture == "NomicBertModel":
             return NomicBertModel
+        if model_architecture == "GemmaForCausalLM":
+            return GemmaModel
         return Model

     def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.BERT
         if arch == "NomicBertModel":
             return gguf.MODEL_ARCH.NOMIC_BERT
+        if arch == "GemmaForCausalLM":
+            return gguf.MODEL_ARCH.GEMMA

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1786,6 +1790,62 @@ def get_tensors(self):
             yield name, data


+class GemmaModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_key_length(hparams["head_dim"])
+        self.gguf_writer.add_value_length(hparams["head_dim"])
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######


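The only model-specific tweak in `write_tensors()` is the `+1` applied to tensors whose names end in `norm.weight`. The `modeling_gemma.py` line referenced in the comment defines `GemmaRMSNorm` to scale the normalized activations by `(1 + weight)` rather than by `weight`, so the converter folds that offset into the exported tensor and the GGUF side can apply an ordinary RMSNorm. Below is a minimal sketch of that equivalence; it is not part of the patch, and the helper names, tensor shapes, and `eps` value are illustrative only.

```python
# Sketch: why the converter bakes "+1" into every norm.weight tensor.
# Gemma's RMSNorm (per the referenced modeling_gemma.py) scales by (1 + weight);
# a conventional RMSNorm scales by the stored weight directly.
import torch


def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize, then scale by (1 + weight), as in GemmaRMSNorm
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * (1.0 + weight)


def plain_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # conventional RMSNorm: normalize, then scale by the stored weight
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * weight


x = torch.randn(2, 8)
w = torch.randn(8)  # weight as stored in the HF checkpoint
# Exporting (w + 1) lets a plain RMSNorm reproduce Gemma's (1 + w) scaling.
assert torch.allclose(gemma_rmsnorm(x, w), plain_rmsnorm(x, w + 1))
```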