Commit 196e237

add --multi-token-probs

1 parent 06bb38e commit 196e237

File tree: 6 files changed, +22 -0 lines changed

common/arg.cpp

Lines changed: 10 additions & 0 deletions

@@ -1057,6 +1057,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-mtp", "--multi-token-probs"},
+        string_format(
+            "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)",
+            params.sampling.multi_token_probs ? "enabled" : "disabled"
+        ),
+        [](common_params & params) {
+            params.sampling.multi_token_probs = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -134,6 +134,7 @@ struct common_params_sampling {
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;
+    bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0)

     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
examples/server/server.cpp

Lines changed: 4 additions & 0 deletions

@@ -239,6 +239,10 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        if (!params_base.sampling.multi_token_probs && params.n_predict > 1 && params.sampling.n_probs > 0) {
+            throw std::runtime_error("For performance reasons, n_probs with n_predict > 1 is not allowed. To enable this, start the server with --multi-token-probs");
+        }
+
         if (params.sampling.dry_base < 1.0f) {
             params.sampling.dry_base = defaults.sampling.dry_base;
         }
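In practice, a request against the native completion endpoint that asks for token probabilities across more than one generated token is rejected unless the server was started with the flag. A hedged sketch of such a request, assuming the server's default address of localhost:8080:

    import requests

    # Without --multi-token-probs this request is rejected by the check above
    # (n_probs > 0 while n_predict > 1); with the flag, each generated token
    # comes back with its top-n_probs candidates.
    res = requests.post("http://localhost:8080/completion", json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        "n_probs": 5,
    })
    print(res.status_code)
    print(res.json())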

examples/server/tests/unit/test_chat_completion.py

Lines changed: 2 additions & 0 deletions

@@ -166,6 +166,7 @@ def test_chat_completion_with_timings_per_token():

 def test_logprobs():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(

@@ -193,6 +194,7 @@ def test_logprobs():

 def test_logprobs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
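Both tests now opt the shared test server into the flag before exercising the OpenAI-compatible logprobs path. A hedged sketch of the kind of request they issue; the host, model name, and prompt are illustrative, and the response fields follow the OpenAI Python client's chat-completions logprobs schema:

    from openai import OpenAI

    client = OpenAI(api_key="dummy", base_url="http://localhost:8080")
    res = client.chat.completions.create(
        model="dummy-model",              # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=8,
        logprobs=True,                    # needs --multi-token-probs when generating more than one token
        top_logprobs=5,
    )
    for item in res.choices[0].logprobs.content:
        print(item.token, item.logprob)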

examples/server/tests/unit/test_completion.py

Lines changed: 2 additions & 0 deletions

@@ -249,6 +249,7 @@ def check_slots_status():

 def test_n_probs():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",

@@ -274,6 +275,7 @@ def test_n_probs():

 def test_n_probs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",

examples/server/tests/utils.py

Lines changed: 3 additions & 0 deletions

@@ -73,6 +73,7 @@ class ServerProcess:
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
+    multi_token_probs: bool | None = None

     # session variables
     process: subprocess.Popen | None = None

@@ -161,6 +162,8 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.multi_token_probs:
+            server_args.append("--multi-token-probs")

         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
