Skip to content

Commit f892c18

Browse files
arch : add T5Gemma encoder-decoder architecture support with improvements (#14940)
- Add T5Gemma model support with proper encoder-decoder architecture - Use super().__init__() instead of manual initialization for better inheritance - Use format_tensor_name() for consistent tensor naming - Explicitly enumerate included keys instead of excluding keys - Add proper type annotations for better type safety - Fix all trailing whitespace issues - Support relative attention bias tensors generation - Handle T5Gemma-specific post-layer normalization tensors - Implement proper tokenizer handling for BPE tokenizer - Add comprehensive tensor mapping for all T5Gemma components
1 parent f5144c1 commit f892c18

File tree

2 files changed

+68
-109
lines changed

2 files changed

+68
-109
lines changed

convert_hf_to_gguf.py

Lines changed: 67 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -6446,114 +6446,61 @@ class T5GemmaModel(TextModel):
64466446
model_arch = gguf.MODEL_ARCH.T5GEMMA
64476447

64486448
def __init__(self, *args, **kwargs):
6449-
# Don't call super().__init__() because it tries to find standard layer count parameters
6450-
# that don't exist in T5Gemma models (they have encoder.num_hidden_layers instead)
6451-
6452-
# Initialize basic attributes manually
6453-
self.dir_model = args[0] if args else kwargs.get('dir_model')
6454-
if self.dir_model is None:
6449+
# Load hyperparameters first to modify them for super().__init__()
6450+
dir_model: Path = args[0] if args else kwargs.get('dir_model')
6451+
if dir_model is None:
64556452
raise ValueError("dir_model is required")
6456-
self.ftype = args[1] if len(args) > 1 else kwargs.get('ftype')
6457-
if self.ftype is None:
6458-
raise ValueError("ftype is required")
6459-
self.fname_out = args[2] if len(args) > 2 else kwargs.get('fname_out')
6460-
if self.fname_out is None:
6461-
raise ValueError("fname_out is required")
6462-
self.is_big_endian = kwargs.get('is_big_endian', False)
6463-
self.endianess = gguf.GGUFEndian.BIG if self.is_big_endian else gguf.GGUFEndian.LITTLE
6464-
self.use_temp_file = kwargs.get('use_temp_file', False)
6465-
self.lazy = not kwargs.get('eager', False)
6466-
self.remote_hf_model_id = kwargs.get('remote_hf_model_id')
6467-
self.metadata_override = kwargs.get('metadata_override')
6468-
self.model_name = kwargs.get('model_name')
6469-
self.dir_model_card = self.dir_model
6470-
6471-
# Load model parts
6472-
if self.remote_hf_model_id is not None:
6473-
self.is_safetensors = True
6474-
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
6475-
if self.remote_hf_model_id is None:
6476-
raise ValueError("remote_hf_model_id is required for remote models")
6477-
logger.info(f"Using remote model with HuggingFace id: {self.remote_hf_model_id}")
6478-
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(self.remote_hf_model_id)
6479-
self.tensor_names = set(name for name in remote_tensors.keys())
6480-
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(self.remote_hf_model_id).items():
6481-
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
6482-
self.get_tensors = get_remote_tensors
6483-
else:
6484-
self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
6485-
self.is_safetensors = len(self.part_names) > 0
6486-
if not self.is_safetensors:
6487-
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
6488-
6489-
# Load hyperparameters
6490-
self.hparams = kwargs.get('hparams') or ModelBase.load_hparams(self.dir_model)
6491-
self.tensor_names = None
6492-
6493-
# Apply heuristics to figure out typical tensor encoding
6494-
if self.ftype == gguf.LlamaFileType.GUESSED:
6495-
_, first_tensor = next(self.get_tensors())
6496-
if first_tensor.dtype == torch.float16:
6497-
logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
6498-
self.ftype = gguf.LlamaFileType.MOSTLY_F16
6499-
else:
6500-
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
6501-
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
6502-
6503-
# Configure GGUF Writer
6504-
self.gguf_writer = gguf.GGUFWriter(
6505-
path=None,
6506-
arch=gguf.MODEL_ARCH_NAMES[self.model_arch],
6507-
endianess=self.endianess,
6508-
use_temp_file=self.use_temp_file,
6509-
split_max_tensors=kwargs.get('split_max_tensors', 0),
6510-
split_max_size=kwargs.get('split_max_size', 0),
6511-
dry_run=kwargs.get('dry_run', False),
6512-
small_first_shard=kwargs.get('small_first_shard', False)
6513-
)
6514-
6453+
6454+
hparams = kwargs.get("hparams") or ModelBase.load_hparams(dir_model)
6455+
encoder_config = hparams.get("encoder", {})
6456+
# Add num_hidden_layers to hparams so super().__init__() can find it
6457+
hparams["num_hidden_layers"] = encoder_config.get("num_hidden_layers", 0)
6458+
kwargs["hparams"] = hparams
6459+
6460+
# Now call super().__init__() with modified hparams
6461+
super().__init__(*args, **kwargs)
6462+
65156463
# T5Gemma specific initialization
65166464
self.is_encoder_decoder = True
6517-
6465+
65186466
# Dynamically get encoder and decoder configurations
6519-
encoder_config = self.hparams.get("encoder", {})
65206467
decoder_config = self.hparams.get("decoder", {})
6521-
6468+
65226469
# Dynamically set encoder and decoder layer counts
65236470
self.encoder_block_count = encoder_config.get("num_hidden_layers", 0)
65246471
self.decoder_block_count = decoder_config.get("num_hidden_layers", 0)
6525-
6472+
65266473
# Set block_count to encoder_block_count for tensor mapping
65276474
self.block_count = self.encoder_block_count
6528-
6475+
65296476
# Initialize tensor mapping using encoder layer count
65306477
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.encoder_block_count)
65316478

65326479
def set_vocab(self):
65336480
# T5Gemma uses BPE tokenizer - read directly from tokenizer.json
65346481
import json
6535-
6482+
65366483
tokenizer_json_path = self.dir_model / "tokenizer.json"
65376484
if not tokenizer_json_path.exists():
65386485
logger.warning("tokenizer.json not found, falling back to GPT2 method")
65396486
self._set_vocab_gpt2()
65406487
return
6541-
6488+
65426489
try:
65436490
with open(tokenizer_json_path, 'r', encoding='utf-8') as f:
65446491
tokenizer_data = json.load(f)
6545-
6492+
65466493
# Extract vocabulary from tokenizer.json
65476494
vocab = tokenizer_data.get("model", {}).get("vocab", {})
65486495
vocab_size = self.hparams.get("vocab_size", len(vocab))
6549-
6496+
65506497
# Create tokens and types lists
65516498
tokens = []
65526499
toktypes = []
6553-
6500+
65546501
# Create reverse mapping from id to token
65556502
id_to_token = {v: k for k, v in vocab.items()}
6556-
6503+
65576504
for i in range(vocab_size):
65586505
if i in id_to_token:
65596506
token = id_to_token[i]
@@ -6566,7 +6513,7 @@ def set_vocab(self):
65666513
else:
65676514
tokens.append(f"[PAD{i}]")
65686515
toktypes.append(gguf.TokenType.UNUSED)
6569-
6516+
65706517
# Extract merges from tokenizer.json if available
65716518
merges = []
65726519
if "merges" in tokenizer_data and tokenizer_data["merges"]:
@@ -6577,7 +6524,7 @@ def set_vocab(self):
65776524
logger.info(f"Found {len(merges)} merges in tokenizer.json model section")
65786525
else:
65796526
logger.warning("No merges found in tokenizer.json")
6580-
6527+
65816528
# Convert merges to the format expected by GGUF
65826529
if merges:
65836530
# merges are in format [["token1", "token2"], ...]
@@ -6587,27 +6534,27 @@ def set_vocab(self):
65876534
if len(merge) == 2:
65886535
gguf_merges.append(f"{merge[0]} {merge[1]}")
65896536
merges = gguf_merges
6590-
6537+
65916538
# Add to GGUF
65926539
self.gguf_writer.add_tokenizer_model("gpt2")
65936540
self.gguf_writer.add_tokenizer_pre("default")
65946541
self.gguf_writer.add_token_list(tokens)
65956542
self.gguf_writer.add_token_types(toktypes)
65966543
if merges:
65976544
self.gguf_writer.add_token_merges(merges)
6598-
6545+
65996546
# Add special tokens
66006547
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
66016548
special_vocab.add_to_gguf(self.gguf_writer)
6602-
6549+
66036550
logger.info(f"Successfully loaded T5Gemma vocabulary with {len(tokens)} tokens")
6604-
6551+
66056552
except Exception as e:
66066553
logger.warning(f"Failed to load T5Gemma tokenizer directly: {e}")
66076554
self._set_vocab_gpt2()
6608-
6555+
66096556
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
6610-
6557+
66116558
# Dynamically set special tokens from config instead of hardcoding
66126559
if "eos_token_id" in self.hparams:
66136560
eos_token_ids = self.hparams["eos_token_id"]
@@ -6617,7 +6564,7 @@ def set_vocab(self):
66176564
elif isinstance(eos_token_ids, list) and len(eos_token_ids) == 1:
66186565
# If only one end token, use it as end_of_turn
66196566
special_vocab._set_special_token("end_of_turn", eos_token_ids[0])
6620-
6567+
66216568
# Dynamically set start_of_turn, usually end_of_turn - 1
66226569
if "eos_token_id" in self.hparams:
66236570
eos_token_ids = self.hparams["eos_token_id"]
@@ -6629,16 +6576,16 @@ def set_vocab(self):
66296576
# Use end_of_turn - 1 as start_of_turn
66306577
start_of_turn_id = eos_token_ids[0] - 1
66316578
special_vocab._set_special_token("start_of_turn", start_of_turn_id)
6632-
6579+
66336580
special_vocab.add_to_gguf(self.gguf_writer)
6634-
6581+
66356582
if "pad_token_id" in self.hparams:
66366583
self.gguf_writer.add_pad_token_id(self.hparams["pad_token_id"])
66376584

66386585
# Dynamically set special token IDs
66396586
if "pad_token_id" in self.hparams:
66406587
self.gguf_writer.add_pad_token_id(self.hparams["pad_token_id"])
6641-
6588+
66426589
# Dynamically set multiple end tokens
66436590
if "eos_token_id" in self.hparams:
66446591
eos_token_ids = self.hparams["eos_token_id"]
@@ -6650,7 +6597,7 @@ def set_vocab(self):
66506597
def set_gguf_parameters(self):
66516598
# Dynamically set encoder parameters
66526599
encoder_config = self.hparams["encoder"]
6653-
6600+
66546601
if "max_position_embeddings" in encoder_config:
66556602
self.gguf_writer.add_context_length(encoder_config["max_position_embeddings"])
66566603
if "hidden_size" in encoder_config:
@@ -6680,32 +6627,34 @@ def set_gguf_parameters(self):
66806627
decoder_config = self.hparams["decoder"]
66816628
if "cross_attention_hidden_size" in decoder_config:
66826629
self.gguf_writer.add_key_value("cross_attention_hidden_size", decoder_config["cross_attention_hidden_size"], gguf.GGUFValueType.UINT32)
6683-
6630+
66846631
# Dynamically set global parameters
66856632
if "vocab_size" in encoder_config:
66866633
self.gguf_writer.add_vocab_size(encoder_config["vocab_size"])
6687-
6634+
66886635
if "dropout_rate" in self.hparams:
66896636
self.gguf_writer.add_key_value("dropout_rate", self.hparams["dropout_rate"], gguf.GGUFValueType.FLOAT32)
66906637
if "classifier_dropout_rate" in self.hparams:
66916638
self.gguf_writer.add_key_value("classifier_dropout_rate", self.hparams["classifier_dropout_rate"], gguf.GGUFValueType.FLOAT32)
6692-
6639+
66936640
if "initializer_range" in self.hparams:
66946641
self.gguf_writer.add_key_value("initializer_range", self.hparams["initializer_range"], gguf.GGUFValueType.FLOAT32)
6695-
6642+
66966643
if "attention_bias" in encoder_config:
66976644
self.gguf_writer.add_key_value("attention_bias", encoder_config["attention_bias"], gguf.GGUFValueType.BOOL)
66986645
if "attention_dropout" in encoder_config:
66996646
self.gguf_writer.add_key_value("attention_dropout", encoder_config["attention_dropout"], gguf.GGUFValueType.FLOAT32)
67006647
if "query_pre_attn_scalar" in encoder_config:
67016648
self.gguf_writer.add_key_value("query_pre_attn_scalar", encoder_config["query_pre_attn_scalar"], gguf.GGUFValueType.UINT32)
6702-
6649+
67036650
# Dynamically set encoder's other parameters
6651+
# Only include specific keys that are known to be useful for T5Gemma
6652+
encoder_keys_to_include = [
6653+
"classifier_dropout_rate", "dropout_rate", "initializer_range",
6654+
"model_type", "torch_dtype", "use_cache", "hidden_activation"
6655+
]
67046656
for key, value in encoder_config.items():
6705-
if key not in ["max_position_embeddings", "hidden_size", "num_hidden_layers", "intermediate_size",
6706-
"num_attention_heads", "num_key_value_heads", "head_dim", "rms_norm_eps",
6707-
"sliding_window", "attn_logit_softcapping", "final_logit_softcapping",
6708-
"rope_theta", "attention_bias", "attention_dropout", "query_pre_attn_scalar", "vocab_size"]:
6657+
if key in encoder_keys_to_include:
67096658
if isinstance(value, bool):
67106659
self.gguf_writer.add_key_value(f"encoder_{key}", value, gguf.GGUFValueType.BOOL)
67116660
elif isinstance(value, int):
@@ -6714,10 +6663,20 @@ def set_gguf_parameters(self):
67146663
self.gguf_writer.add_key_value(f"encoder_{key}", value, gguf.GGUFValueType.FLOAT32)
67156664
elif isinstance(value, str):
67166665
self.gguf_writer.add_key_value(f"encoder_{key}", value, gguf.GGUFValueType.STRING)
6717-
6666+
67186667
# Dynamically set decoder's other parameters
6668+
# Only include specific keys that are known to be useful for T5Gemma
6669+
decoder_keys_to_include = [
6670+
"classifier_dropout_rate", "dropout_rate", "initializer_range",
6671+
"model_type", "torch_dtype", "use_cache", "hidden_activation",
6672+
"is_decoder", "max_position_embeddings", "hidden_size",
6673+
"intermediate_size", "num_attention_heads", "num_key_value_heads",
6674+
"head_dim", "rms_norm_eps", "sliding_window", "attn_logit_softcapping",
6675+
"final_logit_softcapping", "rope_theta", "attention_bias",
6676+
"attention_dropout", "query_pre_attn_scalar", "vocab_size"
6677+
]
67196678
for key, value in decoder_config.items():
6720-
if key not in ["cross_attention_hidden_size"]:
6679+
if key in decoder_keys_to_include:
67216680
if isinstance(value, bool):
67226681
self.gguf_writer.add_key_value(f"decoder_{key}", value, gguf.GGUFValueType.BOOL)
67236682
elif isinstance(value, int):
@@ -6726,10 +6685,10 @@ def set_gguf_parameters(self):
67266685
self.gguf_writer.add_key_value(f"decoder_{key}", value, gguf.GGUFValueType.FLOAT32)
67276686
elif isinstance(value, str):
67286687
self.gguf_writer.add_key_value(f"decoder_{key}", value, gguf.GGUFValueType.STRING)
6729-
6688+
67306689
# T5 models typically use 32 relative attention buckets
67316690
self.gguf_writer.add_relative_attn_buckets_count(32)
6732-
6691+
67336692
self.gguf_writer.add_file_type(self.ftype)
67346693

67356694
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -6761,20 +6720,20 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
67616720
n_head_enc = self.hparams.get("encoder_num_attention_heads", 8)
67626721
n_head_dec = self.hparams.get("decoder_num_attention_heads", 8)
67636722
n_rel_attn_bkts = self.hparams.get("relative_buckets_count", 32)
6764-
6723+
67656724
# Generate relative attention bias for encoder layers
67666725
for i in range(self.block_count):
67676726
# Encoder relative attention bias - shape should be (n_rel_attn_bkts, n_head)
67686727
rel_bias_enc = torch.zeros(n_rel_attn_bkts, n_head_enc, dtype=torch.float16)
6769-
yield f"enc.blk.{i}.attn_rel_b.weight", rel_bias_enc
6770-
6728+
yield self.format_tensor_name(gguf.MODEL_TENSOR.ENC_ATTN_REL_B, i), rel_bias_enc
6729+
67716730
# Decoder relative attention bias - shape should be (n_rel_attn_bkts, n_head)
67726731
rel_bias_dec = torch.zeros(n_rel_attn_bkts, n_head_dec, dtype=torch.float16)
6773-
yield f"dec.blk.{i}.attn_rel_b.weight", rel_bias_dec
6774-
6732+
yield self.format_tensor_name(gguf.MODEL_TENSOR.DEC_ATTN_REL_B, i), rel_bias_dec
6733+
67756734
# Decoder cross attention relative bias - shape should be (n_rel_attn_bkts, n_head)
67766735
rel_bias_cross = torch.zeros(n_rel_attn_bkts, n_head_dec, dtype=torch.float16)
6777-
yield f"dec.blk.{i}.cross_attn_rel_b.weight", rel_bias_cross
6736+
yield self.format_tensor_name(gguf.MODEL_TENSOR.DEC_CROSS_ATTN_REL_B, i), rel_bias_cross
67786737

67796738

67806739
@ModelBase.register("T5EncoderModel")

gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,7 @@ class TensorNameMap:
970970
"decoder.final_layer_norm", # t5
971971
"model.decoder.norm", # t5gemma
972972
),
973-
973+
974974
# T5GEMMA specific post layer normalization tensors
975975
MODEL_TENSOR.DEC_POST_SELF_ATTN_NORM: (
976976
"model.decoder.layers.{bid}.post_self_attn_layernorm", # t5gemma

0 commit comments

Comments (0)