@@ -2853,29 +2853,47 @@ def write_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")
+@Model.register("T5ForConditionalGeneration")
+@Model.register("MT5ForConditionalGeneration")
+@Model.register("UMT5ForConditionalGeneration")
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
         from sentencepiece import SentencePieceProcessor
         from sentencepiece import sentencepiece_model_pb2 as model
 
-        tokenizer_path = self.dir_model / 'spiece.model'
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use 'spiece.model' as the tokenizer model file name
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'
 
         if not tokenizer_path.is_file():
            raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
         sentencepiece_model = model.ModelProto()
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like the Pile-T5 family use a BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
+            # make sure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
         precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
-        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
 
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
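As an aside (not part of the patch), a minimal sketch of the trainer_spec.model_type probe the new set_vocab() code relies on, runnable standalone; the detect_tokenizer_type helper name is illustrative only:

import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"  # same workaround as in set_vocab
from sentencepiece import sentencepiece_model_pb2 as model

def detect_tokenizer_type(path: str) -> str:
    # values follow sentencepiece's ModelType enum: 1 = UNIGRAM, 2 = BPE
    proto = model.ModelProto()
    with open(path, "rb") as f:
        proto.ParseFromString(f.read())
    return {1: "UNIGRAM", 2: "BPE"}.get(proto.trainer_spec.model_type, "OTHER")

# e.g. detect_tokenizer_type("tokenizer.model") is expected to return "BPE" for Pile-T5 checkpoints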
@@ -2945,7 +2963,10 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("T5")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
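For illustration, the n_ctx fallback above boils down to this pattern (a plain dict lookup stands in for the converter's find_hparam helper; the hparams values are hypothetical):

hparams = {"d_model": 768, "d_ff": 2048}  # hypothetical T5-style config without "n_positions"
if (n_ctx := hparams.get("n_positions")) is None:
    n_ctx = 512  # fall back to the conventional T5 default
print(n_ctx)  # -> 512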
@@ -2961,12 +2982,17 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5-based models store the shared token embeddings tensor under varying names: "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight". Some models even contain more than one of these
+        # in their safetensors files. We use the first of the three encountered as the token embeddings for both
+        # encoder and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
 
         return [(self.map_tensor_name(name), data_torch)]
 
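A self-contained sketch of the shared-embedding deduplication in modify_tensors above: the first of the three candidate names is kept and renamed to "shared.weight", later duplicates are dropped (tensor data omitted for brevity; the dedup_names helper is illustrative only):

SHARED_NAMES = {"shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"}

def dedup_names(names: list[str]) -> list[str]:
    found = False
    kept = []
    for name in names:
        if name in SHARED_NAMES:
            if found:
                continue  # a later duplicate of the shared embeddings, skip it
            name, found = "shared.weight", True
        kept.append(name)
    return kept

print(dedup_names(["encoder.embed_tokens.weight", "shared.weight", "decoder.block.0.layer.0.SelfAttention.q.weight"]))
# -> ['shared.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight']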