@@ -6446,114 +6446,61 @@ class T5GemmaModel(TextModel):
64466446 model_arch = gguf .MODEL_ARCH .T5GEMMA
64476447
64486448 def __init__ (self , * args , ** kwargs ):
6449- # Don't call super().__init__() because it tries to find standard layer count parameters
6450- # that don't exist in T5Gemma models (they have encoder.num_hidden_layers instead)
6451-
6452- # Initialize basic attributes manually
6453- self .dir_model = args [0 ] if args else kwargs .get ('dir_model' )
6454- if self .dir_model is None :
6449+ # Load hyperparameters first to modify them for super().__init__()
6450+ dir_model : Path = args [0 ] if args else kwargs .get ('dir_model' )
6451+ if dir_model is None :
64556452 raise ValueError ("dir_model is required" )
6456- self .ftype = args [1 ] if len (args ) > 1 else kwargs .get ('ftype' )
6457- if self .ftype is None :
6458- raise ValueError ("ftype is required" )
6459- self .fname_out = args [2 ] if len (args ) > 2 else kwargs .get ('fname_out' )
6460- if self .fname_out is None :
6461- raise ValueError ("fname_out is required" )
6462- self .is_big_endian = kwargs .get ('is_big_endian' , False )
6463- self .endianess = gguf .GGUFEndian .BIG if self .is_big_endian else gguf .GGUFEndian .LITTLE
6464- self .use_temp_file = kwargs .get ('use_temp_file' , False )
6465- self .lazy = not kwargs .get ('eager' , False )
6466- self .remote_hf_model_id = kwargs .get ('remote_hf_model_id' )
6467- self .metadata_override = kwargs .get ('metadata_override' )
6468- self .model_name = kwargs .get ('model_name' )
6469- self .dir_model_card = self .dir_model
6470-
6471- # Load model parts
6472- if self .remote_hf_model_id is not None :
6473- self .is_safetensors = True
6474- def get_remote_tensors () -> Iterator [tuple [str , Tensor ]]:
6475- if self .remote_hf_model_id is None :
6476- raise ValueError ("remote_hf_model_id is required for remote models" )
6477- logger .info (f"Using remote model with HuggingFace id: { self .remote_hf_model_id } " )
6478- remote_tensors = gguf .utility .SafetensorRemote .get_list_tensors_hf_model (self .remote_hf_model_id )
6479- self .tensor_names = set (name for name in remote_tensors .keys ())
6480- for name , remote_tensor in gguf .utility .SafetensorRemote .get_list_tensors_hf_model (self .remote_hf_model_id ).items ():
6481- yield (name , LazyTorchTensor .from_remote_tensor (remote_tensor ))
6482- self .get_tensors = get_remote_tensors
6483- else :
6484- self .part_names = ModelBase .get_model_part_names (self .dir_model , "model" , ".safetensors" )
6485- self .is_safetensors = len (self .part_names ) > 0
6486- if not self .is_safetensors :
6487- self .part_names = ModelBase .get_model_part_names (self .dir_model , "pytorch_model" , ".bin" )
6488-
6489- # Load hyperparameters
6490- self .hparams = kwargs .get ('hparams' ) or ModelBase .load_hparams (self .dir_model )
6491- self .tensor_names = None
6492-
6493- # Apply heuristics to figure out typical tensor encoding
6494- if self .ftype == gguf .LlamaFileType .GUESSED :
6495- _ , first_tensor = next (self .get_tensors ())
6496- if first_tensor .dtype == torch .float16 :
6497- logger .info (f"choosing --outtype f16 from first tensor type ({ first_tensor .dtype } )" )
6498- self .ftype = gguf .LlamaFileType .MOSTLY_F16
6499- else :
6500- logger .info (f"choosing --outtype bf16 from first tensor type ({ first_tensor .dtype } )" )
6501- self .ftype = gguf .LlamaFileType .MOSTLY_BF16
6502-
6503- # Configure GGUF Writer
6504- self .gguf_writer = gguf .GGUFWriter (
6505- path = None ,
6506- arch = gguf .MODEL_ARCH_NAMES [self .model_arch ],
6507- endianess = self .endianess ,
6508- use_temp_file = self .use_temp_file ,
6509- split_max_tensors = kwargs .get ('split_max_tensors' , 0 ),
6510- split_max_size = kwargs .get ('split_max_size' , 0 ),
6511- dry_run = kwargs .get ('dry_run' , False ),
6512- small_first_shard = kwargs .get ('small_first_shard' , False )
6513- )
6514-
6453+
6454+ hparams = kwargs .get ("hparams" ) or ModelBase .load_hparams (dir_model )
6455+ encoder_config = hparams .get ("encoder" , {})
6456+ # Add num_hidden_layers to hparams so super().__init__() can find it
6457+ hparams ["num_hidden_layers" ] = encoder_config .get ("num_hidden_layers" , 0 )
6458+ kwargs ["hparams" ] = hparams
6459+
6460+ # Now call super().__init__() with modified hparams
6461+ super ().__init__ (* args , ** kwargs )
6462+
65156463 # T5Gemma specific initialization
65166464 self .is_encoder_decoder = True
6517-
6465+
65186466 # Dynamically get encoder and decoder configurations
6519- encoder_config = self .hparams .get ("encoder" , {})
65206467 decoder_config = self .hparams .get ("decoder" , {})
6521-
6468+
65226469 # Dynamically set encoder and decoder layer counts
65236470 self .encoder_block_count = encoder_config .get ("num_hidden_layers" , 0 )
65246471 self .decoder_block_count = decoder_config .get ("num_hidden_layers" , 0 )
6525-
6472+
65266473 # Set block_count to encoder_block_count for tensor mapping
65276474 self .block_count = self .encoder_block_count
6528-
6475+
65296476 # Initialize tensor mapping using encoder layer count
65306477 self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .encoder_block_count )
65316478
65326479 def set_vocab (self ):
65336480 # T5Gemma uses BPE tokenizer - read directly from tokenizer.json
65346481 import json
6535-
6482+
65366483 tokenizer_json_path = self .dir_model / "tokenizer.json"
65376484 if not tokenizer_json_path .exists ():
65386485 logger .warning ("tokenizer.json not found, falling back to GPT2 method" )
65396486 self ._set_vocab_gpt2 ()
65406487 return
6541-
6488+
65426489 try :
65436490 with open (tokenizer_json_path , 'r' , encoding = 'utf-8' ) as f :
65446491 tokenizer_data = json .load (f )
6545-
6492+
65466493 # Extract vocabulary from tokenizer.json
65476494 vocab = tokenizer_data .get ("model" , {}).get ("vocab" , {})
65486495 vocab_size = self .hparams .get ("vocab_size" , len (vocab ))
6549-
6496+
65506497 # Create tokens and types lists
65516498 tokens = []
65526499 toktypes = []
6553-
6500+
65546501 # Create reverse mapping from id to token
65556502 id_to_token = {v : k for k , v in vocab .items ()}
6556-
6503+
65576504 for i in range (vocab_size ):
65586505 if i in id_to_token :
65596506 token = id_to_token [i ]
@@ -6566,7 +6513,7 @@ def set_vocab(self):
65666513 else :
65676514 tokens .append (f"[PAD{ i } ]" )
65686515 toktypes .append (gguf .TokenType .UNUSED )
6569-
6516+
65706517 # Extract merges from tokenizer.json if available
65716518 merges = []
65726519 if "merges" in tokenizer_data and tokenizer_data ["merges" ]:
@@ -6577,7 +6524,7 @@ def set_vocab(self):
65776524 logger .info (f"Found { len (merges )} merges in tokenizer.json model section" )
65786525 else :
65796526 logger .warning ("No merges found in tokenizer.json" )
6580-
6527+
65816528 # Convert merges to the format expected by GGUF
65826529 if merges :
65836530 # merges are in format [["token1", "token2"], ...]
@@ -6587,27 +6534,27 @@ def set_vocab(self):
65876534 if len (merge ) == 2 :
65886535 gguf_merges .append (f"{ merge [0 ]} { merge [1 ]} " )
65896536 merges = gguf_merges
6590-
6537+
65916538 # Add to GGUF
65926539 self .gguf_writer .add_tokenizer_model ("gpt2" )
65936540 self .gguf_writer .add_tokenizer_pre ("default" )
65946541 self .gguf_writer .add_token_list (tokens )
65956542 self .gguf_writer .add_token_types (toktypes )
65966543 if merges :
65976544 self .gguf_writer .add_token_merges (merges )
6598-
6545+
65996546 # Add special tokens
66006547 special_vocab = gguf .SpecialVocab (self .dir_model , load_merges = False )
66016548 special_vocab .add_to_gguf (self .gguf_writer )
6602-
6549+
66036550 logger .info (f"Successfully loaded T5Gemma vocabulary with { len (tokens )} tokens" )
6604-
6551+
66056552 except Exception as e :
66066553 logger .warning (f"Failed to load T5Gemma tokenizer directly: { e } " )
66076554 self ._set_vocab_gpt2 ()
6608-
6555+
66096556 special_vocab = gguf .SpecialVocab (self .dir_model , load_merges = False )
6610-
6557+
66116558 # Dynamically set special tokens from config instead of hardcoding
66126559 if "eos_token_id" in self .hparams :
66136560 eos_token_ids = self .hparams ["eos_token_id" ]
@@ -6617,7 +6564,7 @@ def set_vocab(self):
66176564 elif isinstance (eos_token_ids , list ) and len (eos_token_ids ) == 1 :
66186565 # If only one end token, use it as end_of_turn
66196566 special_vocab ._set_special_token ("end_of_turn" , eos_token_ids [0 ])
6620-
6567+
66216568 # Dynamically set start_of_turn, usually end_of_turn - 1
66226569 if "eos_token_id" in self .hparams :
66236570 eos_token_ids = self .hparams ["eos_token_id" ]
@@ -6629,16 +6576,16 @@ def set_vocab(self):
66296576 # Use end_of_turn - 1 as start_of_turn
66306577 start_of_turn_id = eos_token_ids [0 ] - 1
66316578 special_vocab ._set_special_token ("start_of_turn" , start_of_turn_id )
6632-
6579+
66336580 special_vocab .add_to_gguf (self .gguf_writer )
6634-
6581+
66356582 if "pad_token_id" in self .hparams :
66366583 self .gguf_writer .add_pad_token_id (self .hparams ["pad_token_id" ])
66376584
66386585 # Dynamically set special token IDs
66396586 if "pad_token_id" in self .hparams :
66406587 self .gguf_writer .add_pad_token_id (self .hparams ["pad_token_id" ])
6641-
6588+
66426589 # Dynamically set multiple end tokens
66436590 if "eos_token_id" in self .hparams :
66446591 eos_token_ids = self .hparams ["eos_token_id" ]
@@ -6650,7 +6597,7 @@ def set_vocab(self):
66506597 def set_gguf_parameters (self ):
66516598 # Dynamically set encoder parameters
66526599 encoder_config = self .hparams ["encoder" ]
6653-
6600+
66546601 if "max_position_embeddings" in encoder_config :
66556602 self .gguf_writer .add_context_length (encoder_config ["max_position_embeddings" ])
66566603 if "hidden_size" in encoder_config :
@@ -6680,32 +6627,34 @@ def set_gguf_parameters(self):
66806627 decoder_config = self .hparams ["decoder" ]
66816628 if "cross_attention_hidden_size" in decoder_config :
66826629 self .gguf_writer .add_key_value ("cross_attention_hidden_size" , decoder_config ["cross_attention_hidden_size" ], gguf .GGUFValueType .UINT32 )
6683-
6630+
66846631 # Dynamically set global parameters
66856632 if "vocab_size" in encoder_config :
66866633 self .gguf_writer .add_vocab_size (encoder_config ["vocab_size" ])
6687-
6634+
66886635 if "dropout_rate" in self .hparams :
66896636 self .gguf_writer .add_key_value ("dropout_rate" , self .hparams ["dropout_rate" ], gguf .GGUFValueType .FLOAT32 )
66906637 if "classifier_dropout_rate" in self .hparams :
66916638 self .gguf_writer .add_key_value ("classifier_dropout_rate" , self .hparams ["classifier_dropout_rate" ], gguf .GGUFValueType .FLOAT32 )
6692-
6639+
66936640 if "initializer_range" in self .hparams :
66946641 self .gguf_writer .add_key_value ("initializer_range" , self .hparams ["initializer_range" ], gguf .GGUFValueType .FLOAT32 )
6695-
6642+
66966643 if "attention_bias" in encoder_config :
66976644 self .gguf_writer .add_key_value ("attention_bias" , encoder_config ["attention_bias" ], gguf .GGUFValueType .BOOL )
66986645 if "attention_dropout" in encoder_config :
66996646 self .gguf_writer .add_key_value ("attention_dropout" , encoder_config ["attention_dropout" ], gguf .GGUFValueType .FLOAT32 )
67006647 if "query_pre_attn_scalar" in encoder_config :
67016648 self .gguf_writer .add_key_value ("query_pre_attn_scalar" , encoder_config ["query_pre_attn_scalar" ], gguf .GGUFValueType .UINT32 )
6702-
6649+
67036650 # Dynamically set encoder's other parameters
6651+ # Only include specific keys that are known to be useful for T5Gemma
6652+ encoder_keys_to_include = [
6653+ "classifier_dropout_rate" , "dropout_rate" , "initializer_range" ,
6654+ "model_type" , "torch_dtype" , "use_cache" , "hidden_activation"
6655+ ]
67046656 for key , value in encoder_config .items ():
6705- if key not in ["max_position_embeddings" , "hidden_size" , "num_hidden_layers" , "intermediate_size" ,
6706- "num_attention_heads" , "num_key_value_heads" , "head_dim" , "rms_norm_eps" ,
6707- "sliding_window" , "attn_logit_softcapping" , "final_logit_softcapping" ,
6708- "rope_theta" , "attention_bias" , "attention_dropout" , "query_pre_attn_scalar" , "vocab_size" ]:
6657+ if key in encoder_keys_to_include :
67096658 if isinstance (value , bool ):
67106659 self .gguf_writer .add_key_value (f"encoder_{ key } " , value , gguf .GGUFValueType .BOOL )
67116660 elif isinstance (value , int ):
@@ -6714,10 +6663,20 @@ def set_gguf_parameters(self):
67146663 self .gguf_writer .add_key_value (f"encoder_{ key } " , value , gguf .GGUFValueType .FLOAT32 )
67156664 elif isinstance (value , str ):
67166665 self .gguf_writer .add_key_value (f"encoder_{ key } " , value , gguf .GGUFValueType .STRING )
6717-
6666+
67186667 # Dynamically set decoder's other parameters
6668+ # Only include specific keys that are known to be useful for T5Gemma
6669+ decoder_keys_to_include = [
6670+ "classifier_dropout_rate" , "dropout_rate" , "initializer_range" ,
6671+ "model_type" , "torch_dtype" , "use_cache" , "hidden_activation" ,
6672+ "is_decoder" , "max_position_embeddings" , "hidden_size" ,
6673+ "intermediate_size" , "num_attention_heads" , "num_key_value_heads" ,
6674+ "head_dim" , "rms_norm_eps" , "sliding_window" , "attn_logit_softcapping" ,
6675+ "final_logit_softcapping" , "rope_theta" , "attention_bias" ,
6676+ "attention_dropout" , "query_pre_attn_scalar" , "vocab_size"
6677+ ]
67196678 for key , value in decoder_config .items ():
6720- if key not in [ "cross_attention_hidden_size" ] :
6679+ if key in decoder_keys_to_include :
67216680 if isinstance (value , bool ):
67226681 self .gguf_writer .add_key_value (f"decoder_{ key } " , value , gguf .GGUFValueType .BOOL )
67236682 elif isinstance (value , int ):
@@ -6726,10 +6685,10 @@ def set_gguf_parameters(self):
67266685 self .gguf_writer .add_key_value (f"decoder_{ key } " , value , gguf .GGUFValueType .FLOAT32 )
67276686 elif isinstance (value , str ):
67286687 self .gguf_writer .add_key_value (f"decoder_{ key } " , value , gguf .GGUFValueType .STRING )
6729-
6688+
67306689 # T5 models typically use 32 relative attention buckets
67316690 self .gguf_writer .add_relative_attn_buckets_count (32 )
6732-
6691+
67336692 self .gguf_writer .add_file_type (self .ftype )
67346693
67356694 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
@@ -6761,20 +6720,20 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
67616720 n_head_enc = self .hparams .get ("encoder_num_attention_heads" , 8 )
67626721 n_head_dec = self .hparams .get ("decoder_num_attention_heads" , 8 )
67636722 n_rel_attn_bkts = self .hparams .get ("relative_buckets_count" , 32 )
6764-
6723+
67656724 # Generate relative attention bias for encoder layers
67666725 for i in range (self .block_count ):
67676726 # Encoder relative attention bias - shape should be (n_rel_attn_bkts, n_head)
67686727 rel_bias_enc = torch .zeros (n_rel_attn_bkts , n_head_enc , dtype = torch .float16 )
6769- yield f"enc.blk. { i } .attn_rel_b.weight" , rel_bias_enc
6770-
6728+ yield self . format_tensor_name ( gguf . MODEL_TENSOR . ENC_ATTN_REL_B , i ) , rel_bias_enc
6729+
67716730 # Decoder relative attention bias - shape should be (n_rel_attn_bkts, n_head)
67726731 rel_bias_dec = torch .zeros (n_rel_attn_bkts , n_head_dec , dtype = torch .float16 )
6773- yield f"dec.blk. { i } .attn_rel_b.weight" , rel_bias_dec
6774-
6732+ yield self . format_tensor_name ( gguf . MODEL_TENSOR . DEC_ATTN_REL_B , i ) , rel_bias_dec
6733+
67756734 # Decoder cross attention relative bias - shape should be (n_rel_attn_bkts, n_head)
67766735 rel_bias_cross = torch .zeros (n_rel_attn_bkts , n_head_dec , dtype = torch .float16 )
6777- yield f"dec.blk. { i } .cross_attn_rel_b.weight" , rel_bias_cross
6736+ yield self . format_tensor_name ( gguf . MODEL_TENSOR . DEC_CROSS_ATTN_REL_B , i ) , rel_bias_cross
67786737
67796738
67806739@ModelBase .register ("T5EncoderModel" )
0 commit comments