Skip to content

Commit a854897

Browse files
authored
set eos to sep if missing
ggml-ci
1 parent f172a27 commit a854897

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

gguf-py/gguf/vocab.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
160160
special_cls = (tokenizer_config or {}).get('cls_token')
161161
special_eos = (tokenizer_config or {}).get('eos_token')
162162
special_sep = (tokenizer_config or {}).get('sep_token')
163+
if not special_eos and special_sep and tokenizer_config:
164+
tokenizer_config['eos_token'] = special_eos = special_sep
163165
post_processor = tokenizer.get('post_processor', {})
164166
for processor in post_processor.get('processors', [post_processor]):
165167
if processor.get('type') == 'RobertaProcessing':
@@ -192,12 +194,12 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
192194
special_eos = special_last
193195
self.add_special_token['eos'] = True if special_last == special_eos else False
194196
if special_last != special_eos:
195-
logger.warning(f'Unknown trailing special token {special_first!r} in TemplateProcessing<single>')
197+
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
196198
if tmpl_pair:
197199
seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
198200
seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
199-
if seq_start == 0 or seq_stop == None:
200-
logger.warning(f'TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
201+
if seq_start == 0 or seq_stop is None:
202+
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
201203
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
202204
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
203205
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')

0 commit comments

Comments
 (0)