Commit a9f77a8

server : add openai-style logit_bias support (#14946)
Signed-off-by: Lukas Straub <[email protected]>
1 parent 8a4a856

4 files changed: +90 −1 lines changed

tools/server/README.md

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
 
-`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
+`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. For compatibility with the OpenAI API, a JSON object {"<string or token id>": bias, ...} can also be passed. Default: `[]`
 
 `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
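As a quick illustration of the two request shapes the README now documents (this sketch is editorial, not part of the commit), the snippet below sends the same bias once in the classic array form and once in the new OpenAI-style object form. The token id 15043 is the 'Hello' example from the README; the localhost:8080 address and the Python requests package are assumptions about the local setup.

# Editorial sketch: same logit_bias expressed in both accepted formats.
# Assumes a llama-server instance listening on http://localhost:8080.
import requests

url = "http://localhost:8080/completion"

# Classic llama.cpp array form: pairs of [token id or string, bias].
array_style = {
    "prompt": "Say hi",
    "n_predict": 16,
    "logit_bias": [[15043, 1.0], ["Goodbye", -0.5]],
}

# OpenAI-style object form added by this commit: {"<string or token id>": bias, ...}.
object_style = {
    "prompt": "Say hi",
    "n_predict": 16,
    "logit_bias": {"15043": 1.0, "Goodbye": -0.5},
}

for payload in (array_style, object_style):
    print(requests.post(url, json=payload).json()["content"])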

tools/server/server.cpp

Lines changed: 27 additions & 0 deletions
@@ -473,6 +473,33 @@ struct server_task {
                         }
                     }
                 }
+            } else if (logit_bias != data.end() && logit_bias->is_object()) {
+                const int n_vocab = llama_vocab_n_tokens(vocab);
+                for (const auto & el : logit_bias->items()) {
+                    float bias;
+                    const auto & key = el.key();
+                    const auto & value = el.value();
+                    if (value.is_number()) {
+                        bias = value.get<float>();
+                    } else if (value.is_boolean() && !value.get<bool>()) {
+                        bias = -INFINITY;
+                    } else {
+                        continue;
+                    }
+
+                    char *end;
+                    llama_token tok = strtol(key.c_str(), &end, 10);
+                    if (*end == 0) {
+                        if (tok >= 0 && tok < n_vocab) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    } else {
+                        auto toks = common_tokenize(vocab, key, false);
+                        for (auto tok : toks) {
+                            params.sampling.logit_bias.push_back({tok, bias});
+                        }
+                    }
+                }
             }
 
             params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
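For readers who prefer Python, here is a rough, editorial sketch of what the new object branch above does; it is not the server code. Numeric values set the bias, false maps to negative infinity, keys that parse as integers are treated as token ids, and any other key is tokenized so the bias applies to each resulting token. The tokenize callable and n_vocab argument are stand-ins for common_tokenize and llama_vocab_n_tokens.

from typing import Callable, Dict, List, Tuple, Union

def parse_logit_bias_object(obj: Dict[str, Union[float, bool]],
                            n_vocab: int,
                            tokenize: Callable[[str], List[int]]) -> List[Tuple[int, float]]:
    """Editorial mirror of the object-form handling above (not the actual server code)."""
    out: List[Tuple[int, float]] = []
    for key, value in obj.items():
        # bool must be checked before int/float: bool is a subclass of int in Python
        if isinstance(value, bool):
            if value:
                continue              # `true` carries no bias, skip it
            bias = float("-inf")      # `false` means "never produce this token"
        elif isinstance(value, (int, float)):
            bias = float(value)
        else:
            continue                  # ignore malformed entries, as the server does

        try:
            tok = int(key)            # numeric keys are token ids
            if 0 <= tok < n_vocab:
                out.append((tok, bias))
        except ValueError:
            for tok in tokenize(key): # everything else is tokenized first
                out.append((tok, bias))
    return out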

tools/server/tests/unit/test_chat_completion.py

Lines changed: 29 additions & 0 deletions
@@ -351,3 +351,32 @@ def test_logprobs_stream():
             assert token.top_logprobs is not None
             assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
+
+
+def test_logit_bias():
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    res = server.make_request("POST", "/tokenize", data={
+        "content": " " + " ".join(exclude) + " ",
+    })
+    assert res.status_code == 200
+    tokens = res.body["tokens"]
+    logit_bias = {tok: -100 for tok in tokens}
+
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=64,
+        logit_bias=logit_bias
+    )
+    output_text = res.choices[0].message.content
+    assert output_text
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)

tools/server/tests/unit/test_completion.py

Lines changed: 33 additions & 0 deletions
@@ -444,6 +444,39 @@ def test_n_probs_post_sampling():
         assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
 
 
+@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
+def test_logit_bias(tokenize, openai_style):
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    logit_bias = []
+    if tokenize:
+        res = server.make_request("POST", "/tokenize", data={
+            "content": " " + " ".join(exclude) + " ",
+        })
+        assert res.status_code == 200
+        tokens = res.body["tokens"]
+        logit_bias = [[tok, -100] for tok in tokens]
+
+    else:
+        logit_bias = [[" " + tok + " ", -100] for tok in exclude]
+
+    if openai_style:
+        logit_bias = {el[0]: -100 for el in logit_bias}
+
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 64,
+        "prompt": "What is the best book",
+        "logit_bias": logit_bias,
+        "temperature": 0.0
+    })
+    assert res.status_code == 200
+    output_text = res.body["content"]
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+
 def test_cancel_request():
     global server
     server.n_ctx = 4096

0 commit comments