@@ -2506,6 +2506,112 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])


+@Model.register("XLMRobertaModel")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
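+        # (note: HF's XLM-RoBERTa computes position ids starting at padding_idx + 1,
+        # fairseq-style, so the first pad_token_id + 1 rows of the position embedding
+        # table are never used)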
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        # to avoid TypeError: Descriptors cannot be created directly
+        # exception when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
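+        # (despite the "bpe" in the file name, XLM-RoBERTa ships a unigram
+        # sentencepiece model; ModelType 1 == UNIGRAM in the sentencepiece proto)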
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
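+        # pre-fill every slot with a [PADn] placeholder; hparams vocab_size can be
+        # larger than the sentencepiece vocab (e.g. extra special tokens), and any
+        # slot the loop below does not overwrite keeps this placeholder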
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        # realign tokens (see HF tokenizer code)
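+        # (the HF XLM-RoBERTa tokenizer uses a fairseq-style layout with <s>, <pad>,
+        # </s>, <unk> at ids 0-3 and the remaining sentencepiece pieces shifted behind
+        # them; dropping the first three sp specials and the last entry keeps the ids
+        # aligned with the HF vocab)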
+        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+        toktypes = [
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.CONTROL,
+            SentencePieceTokenTypes.UNKNOWN,
+        ] + toktypes[3:-1]
+
+        self.gguf_writer.add_tokenizer_model("t5")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
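+        # XLM-RoBERTa uses a single segment/token type (type_vocab_size == 1),
+        # unlike BERT's two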
+        self.gguf_writer.add_token_type_count(1)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+        if precompiled_charsmap:
+            self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
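+        # XLM-RoBERTa wraps sequences as <s> ... </s>, so both BOS and EOS are added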
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
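+        # (this keeps the tensor consistent with the reduced max_position_embeddings
+        # set in __init__; the first pad_token_id + 1 rows are never indexed)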
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:, :]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA