@@ -167,81 +167,81 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
167167 tokenizer_config ['bos_token' ] = special_bos = special_cls
168168 if not special_eos and special_sep and tokenizer_config :
169169 tokenizer_config ['eos_token' ] = special_eos = special_sep
170- post_processor = tokenizer .get ('post_processor' , {})
171- for processor in post_processor .get ('processors' , [post_processor ]):
172- if processor .get ('type' ) == 'RobertaProcessing' :
173- self .add_special_token ['bos' ] = True
174- self .add_special_token ['eos' ] = True
175- self .add_special_token ['sep' ] = True
176- if not special_cls and tokenizer_config :
177- special_cls = processor .get ('cls' , [special_bos ])[0 ]
178- tokenizer_config ['cls_token' ] = special_cls
179- if not special_sep and tokenizer_config :
180- special_sep = processor .get ('sep' , [special_eos ])[0 ]
181- tokenizer_config ['sep_token' ] = special_sep
182- continue
183- # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184- # Only works with simple templates, **will** get it wrong on unusual sequences
185- if processor .get ('type' ) == 'TemplateProcessing' :
186- tmpl_single = processor .get ('single' , [])
187- tmpl_pair = processor .get ('pair' , [])
188- special_first = None
189- special_last = None
190- if len (tmpl_single ) > 1 :
191- if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
192- if not tokenizer_config :
193- special_bos = special_first
194- self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
195- if special_first not in (special_bos , special_cls ):
196- logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
197- if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
198- if not tokenizer_config :
199- special_eos = special_last
200- elif special_last != special_eos :
201- if 'eot' not in self .special_token_types :
202- self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
203- tokenizer_config ['eot_token' ] = special_eos
204- elif 'eom' not in self .special_token_types :
205- self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
206- tokenizer_config ['eom_token' ] = special_eos
207- else :
208- logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
209- tokenizer_config ['eos_token' ] = special_eos = special_last
210- self .add_special_token ['eos' ] = True if special_last == special_eos else False
211- if special_last != special_eos :
212- logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
213- if tmpl_pair :
214- seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
215- seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
216- if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
217- logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
218- if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
219- tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
220- tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
221- if tmpl_a != 'A' or tmpl_b != 'B' :
222- logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
223- # A [sep] [eos] B
224- if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
225- add_sep = False
226- if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
227- if special_entry in (special_sep , special_eos ) and not special_last :
228- add_sep = True
229- if special_entry not in (special_sep , special_eos ):
230- logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
231- else :
232- logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
233- if len (tmpl_pair ) == 2 :
234- if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
235- if special_entry in (special_sep , special_eos ):
170+ if post_processor := tokenizer .get ('post_processor' ):
171+ for processor in post_processor .get ('processors' , [post_processor ]):
172+ if processor .get ('type' ) == 'RobertaProcessing' :
173+ self .add_special_token ['bos' ] = True
174+ self .add_special_token ['eos' ] = True
175+ self .add_special_token ['sep' ] = True
176+ if not special_cls and tokenizer_config :
177+ special_cls = processor .get ('cls' , [special_bos ])[0 ]
178+ tokenizer_config ['cls_token' ] = special_cls
179+ if not special_sep and tokenizer_config :
180+ special_sep = processor .get ('sep' , [special_eos ])[0 ]
181+ tokenizer_config ['sep_token' ] = special_sep
182+ continue
183+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
184+ # Only works with simple templates, **will** get it wrong on unusual sequences
185+ if processor .get ('type' ) == 'TemplateProcessing' :
186+ tmpl_single = processor .get ('single' , [])
187+ tmpl_pair = processor .get ('pair' , [])
188+ special_first = None
189+ special_last = None
190+ if len (tmpl_single ) > 1 :
191+ if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
192+ if not tokenizer_config :
193+ special_bos = special_first
194+ self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
195+ if special_first not in (special_bos , special_cls ):
196+ logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
197+ if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
198+ if not tokenizer_config :
199+ special_eos = special_last
200+ elif special_last != special_eos :
201+ if 'eot' not in self .special_token_types :
202+ self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
203+ tokenizer_config ['eot_token' ] = special_eos
204+ elif 'eom' not in self .special_token_types :
205+ self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
206+ tokenizer_config ['eom_token' ] = special_eos
207+ else :
208+ logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
209+ tokenizer_config ['eos_token' ] = special_eos = special_last
210+ self .add_special_token ['eos' ] = True if special_last == special_eos else False
211+ if special_last != special_eos :
212+ logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
213+ if tmpl_pair :
214+ seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
215+ seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
216+ if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
217+ logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
218+ if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
219+ tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
220+ tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
221+ if tmpl_a != 'A' or tmpl_b != 'B' :
222+ logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
223+ # A [sep] [eos] B
224+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
225+ add_sep = False
226+ if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
227+ if special_entry in (special_sep , special_eos ) and not special_last :
236228 add_sep = True
237229 if special_entry not in (special_sep , special_eos ):
238- logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
230+ logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
239231 else :
240- logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
241- self .add_special_token ['sep' ] = add_sep
242- if add_sep and not special_sep and tokenizer_config :
243- tokenizer_config ['sep_token' ] = special_eos
244- continue
232+ logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
233+ if len (tmpl_pair ) == 2 :
234+ if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
235+ if special_entry in (special_sep , special_eos ):
236+ add_sep = True
237+ if special_entry not in (special_sep , special_eos ):
238+ logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
239+ else :
240+ logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
241+ self .add_special_token ['sep' ] = add_sep
242+ if add_sep and not special_sep and tokenizer_config :
243+ tokenizer_config ['sep_token' ] = special_eos
244+ continue
245245 if not tokenizer_config :
246246 return True
247247 chat_template_alt = None
0 commit comments