@@ -168,81 +168,81 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
168168 tokenizer_config ['bos_token' ] = special_bos = special_cls
169169 if not special_eos and special_sep and tokenizer_config :
170170 tokenizer_config ['eos_token' ] = special_eos = special_sep
171- post_processor = tokenizer .get ('post_processor' , {})
172- for processor in post_processor .get ('processors' , [post_processor ]):
173- if processor .get ('type' ) == 'RobertaProcessing' :
174- self .add_special_token ['bos' ] = True
175- self .add_special_token ['eos' ] = True
176- self .add_special_token ['sep' ] = True
177- if not special_cls and tokenizer_config :
178- special_cls = processor .get ('cls' , [special_bos ])[0 ]
179- tokenizer_config ['cls_token' ] = special_cls
180- if not special_sep and tokenizer_config :
181- special_sep = processor .get ('sep' , [special_eos ])[0 ]
182- tokenizer_config ['sep_token' ] = special_sep
183- continue
184- # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
185- # Only works with simple templates, **will** get it wrong on unusual sequences
186- if processor .get ('type' ) == 'TemplateProcessing' :
187- tmpl_single = processor .get ('single' , [])
188- tmpl_pair = processor .get ('pair' , [])
189- special_first = None
190- special_last = None
191- if len (tmpl_single ) > 1 :
192- if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
193- if not tokenizer_config :
194- special_bos = special_first
195- self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
196- if special_first not in (special_bos , special_cls ):
197- logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
198- if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
199- if not tokenizer_config :
200- special_eos = special_last
201- elif special_last != special_eos :
202- if 'eot' not in self .special_token_types :
203- self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
204- tokenizer_config ['eot_token' ] = special_eos
205- elif 'eom' not in self .special_token_types :
206- self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
207- tokenizer_config ['eom_token' ] = special_eos
208- else :
209- logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
210- tokenizer_config ['eos_token' ] = special_eos = special_last
211- self .add_special_token ['eos' ] = True if special_last == special_eos else False
212- if special_last != special_eos :
213- logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
214- if tmpl_pair :
215- seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
216- seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
217- if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
218- logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
219- if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
220- tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
221- tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
222- if tmpl_a != 'A' or tmpl_b != 'B' :
223- logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
224- # A [sep] [eos] B
225- if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
226- add_sep = False
227- if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
228- if special_entry in (special_sep , special_eos ) and not special_last :
229- add_sep = True
230- if special_entry not in (special_sep , special_eos ):
231- logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
232- else :
233- logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
234- if len (tmpl_pair ) == 2 :
235- if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
236- if special_entry in (special_sep , special_eos ):
171+ if post_processor := tokenizer .get ('post_processor' ):
172+ for processor in post_processor .get ('processors' , [post_processor ]):
173+ if processor .get ('type' ) == 'RobertaProcessing' :
174+ self .add_special_token ['bos' ] = True
175+ self .add_special_token ['eos' ] = True
176+ self .add_special_token ['sep' ] = True
177+ if not special_cls and tokenizer_config :
178+ special_cls = processor .get ('cls' , [special_bos ])[0 ]
179+ tokenizer_config ['cls_token' ] = special_cls
180+ if not special_sep and tokenizer_config :
181+ special_sep = processor .get ('sep' , [special_eos ])[0 ]
182+ tokenizer_config ['sep_token' ] = special_sep
183+ continue
184+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
185+ # Only works with simple templates, **will** get it wrong on unusual sequences
186+ if processor .get ('type' ) == 'TemplateProcessing' :
187+ tmpl_single = processor .get ('single' , [])
188+ tmpl_pair = processor .get ('pair' , [])
189+ special_first = None
190+ special_last = None
191+ if len (tmpl_single ) > 1 :
192+ if special_first := tmpl_single [0 ].get ('SpecialToken' , {}).get ('id' ):
193+ if not tokenizer_config :
194+ special_bos = special_first
195+ self .add_special_token ['bos' ] = True if special_first in (special_bos , special_cls ) else False
196+ if special_first not in (special_bos , special_cls ):
197+ logger .warning (f'Unknown leading special token { special_first !r} in TemplateProcessing<single>' )
198+ if special_last := tmpl_single [- 1 ].get ('SpecialToken' , {}).get ('id' ):
199+ if not tokenizer_config :
200+ special_eos = special_last
201+ elif special_last != special_eos :
202+ if 'eot' not in self .special_token_types :
203+ self .special_token_types = tuple (self .special_token_types ) + ('eot' , )
204+ tokenizer_config ['eot_token' ] = special_eos
205+ elif 'eom' not in self .special_token_types :
206+ self .special_token_types = tuple (self .special_token_types ) + ('eom' , )
207+ tokenizer_config ['eom_token' ] = special_eos
208+ else :
209+ logger .warning (f'Overriding EOS token { special_eos !r} with { special_last !r} without EOT/EOM fallback!' )
210+ tokenizer_config ['eos_token' ] = special_eos = special_last
211+ self .add_special_token ['eos' ] = True if special_last == special_eos else False
212+ if special_last != special_eos :
213+ logger .warning (f'Unknown trailing special token { special_last !r} in TemplateProcessing<single>' )
214+ if tmpl_pair :
215+ seq_start = 1 if special_first and tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ) == special_first else 0
216+ seq_stop = - 1 if special_last and tmpl_pair [- 1 ].get ('SpecialToken' , {}).get ('id' ) == special_last else None
217+ if (special_first and seq_start == 0 ) or (special_last and seq_stop is None ):
218+ logger .warning ('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>' )
219+ if tmpl_pair := tmpl_pair [slice (seq_start , seq_stop )]:
220+ tmpl_a = tmpl_pair [0 ].get ('Sequence' , {}).get ('id' )
221+ tmpl_b = tmpl_pair [- 1 ].get ('Sequence' , {}).get ('id' )
222+ if tmpl_a != 'A' or tmpl_b != 'B' :
223+ logger .warning (f'Unknown sequence { tmpl_a } ...{ tmpl_b } in TemplateProcessing<pair>' )
224+ # A [sep] [eos] B
225+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair [1 :- 1 ]):
226+ add_sep = False
227+ if special_entry := tmpl_pair [0 ].get ('SpecialToken' , {}).get ('id' ):
228+ if special_entry in (special_sep , special_eos ) and not special_last :
237229 add_sep = True
238230 if special_entry not in (special_sep , special_eos ):
239- logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
231+ logger .warning (f'Unknown separator token { special_entry !r} in TemplateProcessing<pair>' )
240232 else :
241- logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
242- self .add_special_token ['sep' ] = add_sep
243- if add_sep and not special_sep and tokenizer_config :
244- tokenizer_config ['sep_token' ] = special_eos
245- continue
233+ logger .warning (f'Unknown middle sequence { tmpl_pair [0 ]!r} in TemplateProcessing<pair>' )
234+ if len (tmpl_pair ) == 2 :
235+ if special_entry := tmpl_pair [1 ].get ('SpecialToken' , {}).get ('id' ):
236+ if special_entry in (special_sep , special_eos ):
237+ add_sep = True
238+ if special_entry not in (special_sep , special_eos ):
239+ logger .warning (f'Unknown second separator token { special_entry !r} in TemplateProcessing<pair>' )
240+ else :
241+ logger .warning (f'Unknown second middle sequence { tmpl_pair [1 ]!r} in TemplateProcessing<pair>' )
242+ self .add_special_token ['sep' ] = add_sep
243+ if add_sep and not special_sep and tokenizer_config :
244+ tokenizer_config ['sep_token' ] = special_eos
245+ continue
246246 if not tokenizer_config :
247247 return True
248248 chat_template_alt = None
0 commit comments