Commit 3fedcaa

CISC authored and Nexesenex committed
llama : improve sep token handling (ggml-org#14272)
1 parent 5598a14 commit 3fedcaa

14 files changed: +160 additions, −28 deletions

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2709,6 +2709,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embd_sep = value;
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    add_opt(common_arg(
+        {"--cls-separator"}, "STRING",
+        "separator of classification sequences (default \\t) for example \"<#seq#>\"",
+        [](common_params & params, const std::string & value) {
+            params.cls_sep = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--host"}, "HOST",
         string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -354,6 +354,7 @@ struct common_params {
     int32_t embd_normalize = 2;  // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = "";   // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
+    std::string cls_sep = "\t";  // separator of classification sequences
 
     // server params
     int32_t port = 8080;         // server listens on this network port

convert_hf_to_gguf.py

Lines changed: 0 additions & 14 deletions
@@ -2474,7 +2474,6 @@ def __init__(self, *args, **kwargs):
 
     def set_vocab(self):
         self._set_vocab_gpt2()
-        self.gguf_writer.add_add_bos_token(True)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -4247,9 +4246,6 @@ def _xlmroberta_set_vocab(self) -> None:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
-
 
 @ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
 class DistilBertModel(BertModel):
@@ -4291,8 +4287,6 @@ def set_vocab(self):
         bpe_tok_path = self.dir_model / "tokenizer.json"
         if bpe_tok_path.exists():
             self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)
 
         # we need this to validate the size of the token_type embeddings
         # though currently we are passing all zeros to the token_type embeddings
@@ -5177,8 +5171,6 @@ def set_vocab(self):
             self.gguf_writer.add_token_type_count(2)
         else:
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
 
 
 @ModelBase.register("OpenELMForCausalLM")
@@ -5780,9 +5772,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5920,9 +5909,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")

examples/embedding/embedding.cpp

Lines changed: 30 additions & 4 deletions
@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
     // max batch size
     const uint64_t n_batch = params.n_batch;
 
+    // get added sep and eos token, if any
+    const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
+    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = common_tokenize(ctx, prompt, true, true);
+        std::vector<llama_token> inp;
+
+        // split classification pairs and insert expected separator tokens
+        if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
+            std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            std::string final_prompt;
+
+            for (size_t i = 0; i < pairs.size(); i++) {
+                final_prompt += pairs[i];
+                if (i != pairs.size() - 1) {
+                    if (!added_eos_token.empty()) {
+                        final_prompt += added_eos_token;
+                    }
+                    if (!added_sep_token.empty()) {
+                        final_prompt += added_sep_token;
+                    }
+                }
+            }
+
+            inp = common_tokenize(ctx, final_prompt, true, true);
+        } else {
+            inp = common_tokenize(ctx, prompt, true, true);
+        }
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // check if the last token is SEP
+    // check if the last token is SEP/EOS
     // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
-            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+        if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
+            LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
            LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
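To make the effect of the splitting logic concrete: the loop above is equivalent to joining the pair segments with the added EOS text followed by the added SEP text. A minimal Python restatement (a sketch, not project code; the "</s>" token texts are an assumption for an XLM-R-style reranker):

def join_pairs(prompt: str, cls_sep: str, added_eos: str, added_sep: str) -> str:
    # mirrors the C++ loop: every boundary between segments gets the EOS text
    # first, then the SEP text; either may be "" if the model does not add it
    return (added_eos + added_sep).join(prompt.split(cls_sep))

# assumed token texts for an XLM-R-style reranker (EOS and SEP both "</s>")
print(join_pairs("what is panda?\tThe giant panda is a bear species.", "\t", "</s>", "</s>"))
# -> what is panda?</s></s>The giant panda is a bear species.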

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ class Tokenizer:
        MASK_ID              = "tokenizer.ggml.mask_token_id"
        ADD_BOS              = "tokenizer.ggml.add_bos_token"
        ADD_EOS              = "tokenizer.ggml.add_eos_token"
+       ADD_SEP              = "tokenizer.ggml.add_sep_token"
        ADD_PREFIX           = "tokenizer.ggml.add_space_prefix"
        REMOVE_EXTRA_WS      = "tokenizer.ggml.remove_extra_whitespaces"
        PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -987,6 +987,9 @@ def add_add_bos_token(self, value: bool) -> None:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_add_sep_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_SEP, value)
+
     def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
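A quick way to confirm that a re-converted model carries the new flag is to read it back with gguf-py's GGUFReader. A minimal sketch (the model path is hypothetical, and the parts[-1] access assumes the reader's usual layout for scalar fields):

from gguf import GGUFReader

reader = GGUFReader("model.gguf")  # hypothetical path to a converted model
field = reader.get_field("tokenizer.ggml.add_sep_token")
if field is None:
    print("add_sep_token not present (converted before this change)")
else:
    # scalar fields keep their value in the last part; bool is stored as uint8
    print("add_sep_token =", bool(field.parts[-1][0]))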

gguf-py/gguf/vocab.py

Lines changed: 80 additions & 3 deletions
@@ -119,6 +119,7 @@ def _set_special_token(self, typ: str, tid: Any) -> None:
         logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
 
     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer = None
         tokenizer_file = path / 'tokenizer.json'
         if tokenizer_file.is_file():
             with open(tokenizer_file, encoding = 'utf-8') as f:
@@ -152,11 +153,87 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
             added_tokens = tokenizer.get('added_tokens', {})
         else:
             added_tokens = {}
+        tokenizer_config = None
         tokenizer_config_file = path / 'tokenizer_config.json'
-        if not tokenizer_config_file.is_file():
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, encoding = 'utf-8') as f:
+                tokenizer_config = json.load(f)
+        if tokenizer:
+            special_bos = (tokenizer_config or {}).get('bos_token')
+            special_cls = (tokenizer_config or {}).get('cls_token')
+            special_eos = (tokenizer_config or {}).get('eos_token')
+            special_sep = (tokenizer_config or {}).get('sep_token')
+            if not special_bos and special_cls and tokenizer_config:
+                tokenizer_config['bos_token'] = special_bos = special_cls
+            if not special_eos and special_sep and tokenizer_config:
+                tokenizer_config['eos_token'] = special_eos = special_sep
+            post_processor = tokenizer.get('post_processor', {})
+            for processor in post_processor.get('processors', [post_processor]):
+                if processor.get('type') == 'RobertaProcessing':
+                    self.add_special_token['bos'] = True
+                    self.add_special_token['eos'] = True
+                    self.add_special_token['sep'] = True
+                    if not special_cls and tokenizer_config:
+                        special_cls = processor.get('cls', [special_bos])[0]
+                        tokenizer_config['cls_token'] = special_cls
+                    if not special_sep and tokenizer_config:
+                        special_sep = processor.get('sep', [special_eos])[0]
+                        tokenizer_config['sep_token'] = special_sep
+                    continue
+                # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+                # Only works with simple templates, **will** get it wrong on unusual sequences
+                if processor.get('type') == 'TemplateProcessing':
+                    tmpl_single = processor.get('single', [])
+                    tmpl_pair = processor.get('pair', [])
+                    special_first = None
+                    special_last = None
+                    if len(tmpl_single) > 1:
+                        if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_bos = special_first
+                            self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+                            if special_first not in (special_bos, special_cls):
+                                logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+                        if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_eos = special_last
+                            self.add_special_token['eos'] = True if special_last == special_eos else False
+                            if special_last != special_eos:
+                                logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+                    if tmpl_pair:
+                        seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+                        seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+                        if seq_start == 0 or seq_stop is None:
+                            logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+                        if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+                            tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+                            tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+                            if tmpl_a != 'A' or tmpl_b != 'B':
+                                logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+                            # A [sep] [eos] B
+                            if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+                                add_sep = False
+                                if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+                                    if special_entry in (special_sep, special_eos) and not special_last:
+                                        add_sep = True
+                                    if special_entry not in (special_sep, special_eos):
+                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
+                                else:
+                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+                                if len(tmpl_pair) == 2:
+                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos):
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                    else:
+                                        logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+                                self.add_special_token['sep'] = add_sep
+                                if add_sep and not special_sep and tokenizer_config:
+                                    tokenizer_config['sep_token'] = special_eos
+                    continue
+        if not tokenizer_config:
             return True
-        with open(tokenizer_config_file, encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
         chat_template_alt = None
         chat_template_file = path / 'chat_template.json'
         if chat_template_file.is_file():
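To see what the crude TemplateProcessing parsing above is matching, here is a representative XLM-R-style post-processor entry as the parser would receive it from json.load (reconstructed from memory, so treat the exact token ids as illustrative):

processor = {
    "type": "TemplateProcessing",
    "single": [  # <s> A </s>
        {"SpecialToken": {"id": "<s>", "type_id": 0}},
        {"Sequence": {"id": "A", "type_id": 0}},
        {"SpecialToken": {"id": "</s>", "type_id": 0}},
    ],
    "pair": [    # <s> A </s> </s> B </s>
        {"SpecialToken": {"id": "<s>", "type_id": 0}},
        {"Sequence": {"id": "A", "type_id": 0}},
        {"SpecialToken": {"id": "</s>", "type_id": 0}},
        {"SpecialToken": {"id": "</s>", "type_id": 0}},
        {"Sequence": {"id": "B", "type_id": 0}},
        {"SpecialToken": {"id": "</s>", "type_id": 0}},
    ],
}

Tracing the branches: the leading <s> in single matches bos_token/cls_token (add bos), the trailing </s> matches eos_token (add eos), and once the matched leading/trailing specials are sliced off pair, the two consecutive </s> entries between A and B indicate a separator beyond what add_eos already supplies, so add_special_token['sep'] ends up True. A BERT-style pair ([CLS] A [SEP] B [SEP]) has only one middle [SEP], which the trailing [SEP] of its single template (add_eos) already accounts for, so sep stays False there.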

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -1105,6 +1105,7 @@ extern "C" {
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
+    LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
 
     LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
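Together with the existing add_bos/add_eos getters, this lets callers such as the embedding example above ask whether the tokenizer is expected to insert a SEP between classification sequences, and fetch its text via llama_vocab_get_text(vocab, llama_vocab_sep(vocab)).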

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id"            },
     { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token"            },
     { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token"            },
+    { LLM_KV_TOKENIZER_ADD_SEP,              "tokenizer.ggml.add_sep_token"            },
     { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix"         },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap"     },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -194,6 +194,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
+    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
