Skip to content

Commit 01afafe

Browse files
committed
add std::log
1 parent 7828013 commit 01afafe

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

examples/server/server.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,6 @@ struct server_task {
342342
}
343343
}
344344

345-
if (params.sampling.n_probs > 0 && params.cache_prompt) {
346-
SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
347-
params.cache_prompt = false;
348-
}
349-
350345
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
351346
params.oaicompat_model = json_value(data, "model", model_name);
352347

@@ -439,7 +434,7 @@ struct completion_token_output {
439434
{"id", p.tok},
440435
{"token", tok_str},
441436
{"bytes", str_to_bytes(p.tok_str)},
442-
{"logprob", p.prob},
437+
{"logprob", logarithm(p.prob)},
443438
});
444439
}
445440
return probs_for_token;
@@ -453,14 +448,19 @@ struct completion_token_output {
453448
out.push_back(json {
454449
{"id", it.tok},
455450
{"token", tok_str},
456-
{"logprob", it.prob},
451+
{"logprob", logarithm(it.prob)},
457452
{"bytes", str_to_bytes(it.text_to_send)},
458453
{"top_logprobs", it.to_json()},
459454
});
460455
}
461456
return out;
462457
}
463458

459+
/// Converts a token probability into a log-probability that can be emitted as a JSON number.
///
/// @param x  probability value (callers pass the `prob` field of a token)
/// @return   std::log(x) when x > 0; otherwise the lowest finite float
static float logarithm(float x) {
    // nlohmann::json converts -inf (and NaN) to null, so we need to prevent that:
    // any input without a well-defined logarithm -- zero, a negative value or NaN --
    // is mapped to the lowest finite float. The negated comparison is deliberate:
    // it is also true for NaN, which a plain `x <= 0.0f` would let through.
    if (!(x > 0.0f)) {
        return std::numeric_limits<float>::lowest();
    }
    return std::log(x);
}
463+
464464
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
465465
std::vector<unsigned char> bytes;
466466
for (unsigned char c : str) {

examples/server/tests/unit/test_chat_completion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_logprobs():
185185
assert res.choices[0].logprobs.content is not None
186186
for token in res.choices[0].logprobs.content:
187187
aggregated_text += token.token
188-
assert 0.0 <= token.logprob <= 1.0
188+
assert token.logprob <= 0.0
189189
assert token.bytes is not None and len(token.bytes) > 0
190190
assert len(token.top_logprobs) > 0
191191
assert aggregated_text == output_text
@@ -218,7 +218,7 @@ def test_logprobs_stream():
218218
assert choice.logprobs.content is not None
219219
for token in choice.logprobs.content:
220220
aggregated_text += token.token
221-
assert 0.0 <= token.logprob <= 1.0
221+
assert token.logprob <= 0.0
222222
assert token.bytes is not None and len(token.bytes) > 0
223223
assert token.top_logprobs is not None
224224
assert len(token.top_logprobs) > 0

examples/server/tests/unit/test_completion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,13 @@ def test_n_probs():
262262
for tok in res.body["completion_probabilities"]:
263263
assert "id" in tok and tok["id"] > 0
264264
assert "token" in tok and type(tok["token"]) == str
265-
assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
265+
assert "logprob" in tok and tok["logprob"] <= 0.0
266266
assert "bytes" in tok and len(tok["bytes"]) > 0
267267
assert len(tok["top_logprobs"]) == 10
268268
for prob in tok["top_logprobs"]:
269269
assert "id" in prob and prob["id"] > 0
270270
assert "token" in prob and type(prob["token"]) == str
271-
assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
271+
assert "logprob" in prob and prob["logprob"] <= 0.0
272272
assert "bytes" in prob and len(prob["bytes"]) > 0
273273

274274

@@ -289,11 +289,11 @@ def test_n_probs_stream():
289289
for tok in data["completion_probabilities"]:
290290
assert "id" in tok and tok["id"] > 0
291291
assert "token" in tok and type(tok["token"]) == str
292-
assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0
292+
assert "logprob" in tok and tok["logprob"] <= 0.0
293293
assert "bytes" in tok and len(tok["bytes"]) > 0
294294
assert len(tok["top_logprobs"]) == 10
295295
for prob in tok["top_logprobs"]:
296296
assert "id" in prob and prob["id"] > 0
297297
assert "token" in prob and type(prob["token"]) == str
298-
assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0
298+
assert "logprob" in prob and prob["logprob"] <= 0.0
299299
assert "bytes" in prob and len(prob["bytes"]) > 0

0 commit comments

Comments
 (0)