
Commit 905620c

[Chat] Support chat completion config override (#2412)
This PR adds argument-override support to the chat CLI. The arguments currently supported are: `top_p`, `temperature`, `presence_penalty`, `frequency_penalty`, `max_tokens`, `seed`, and `stop`. It also adds the corresponding parsing support to the ChatCompletion request handling in JSONFFIEngine.
1 parent 7eba612 commit 905620c
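For example, inside the chat CLI these overrides can be set with `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`. A minimal sketch of the equivalent programmatic call, based on how the chat code below invokes the engine (engine construction is elided; the keyword arguments mirror the new override fields):

    for response in engine.chat.completions.create(
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        temperature=0.5,
        top_p=0.8,
        seed=23,
        max_tokens=100,
        stop=["str1", "str2"],
    ):
        ...  # stream deltas as in python/mlc_llm/interface/chat.py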

File tree

3 files changed (+96 / -12 lines)


cpp/json_ffi/json_ffi_engine.cc

Lines changed: 2 additions & 2 deletions
@@ -91,8 +91,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
   gen_cfg->logprobs = request.logprobs;
   gen_cfg->top_logprobs = request.top_logprobs;
   gen_cfg->logit_bias = request.logit_bias.value_or(default_gen_cfg->logit_bias);
-  gen_cfg->seed = request.seed.value_or(default_gen_cfg->seed);
-  gen_cfg->max_tokens = request.seed.value_or(default_gen_cfg->max_tokens);
+  gen_cfg->seed = request.seed.value_or(std::random_device{}());
+  gen_cfg->max_tokens = request.max_tokens.value_or(default_gen_cfg->max_tokens);
   gen_cfg->stop_strs = std::move(stop_strs);
   gen_cfg->stop_token_ids = conv_template_.stop_token_ids;
   gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());
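Note the bug fix folded into this hunk: `gen_cfg->max_tokens` was previously populated from `request.seed`; it now correctly reads `request.max_tokens`. The seed fallback also changes from the default generation config to a fresh `std::random_device` draw. A minimal Python sketch of the resulting fallback semantics (an analogue only; the engine implements this in C++, and `default_max_tokens` stands in for the default generation config):

    import random

    def resolve_gen_config(request_seed, request_max_tokens, default_max_tokens):
        # seed: honor the request's seed when given, otherwise draw a fresh random one
        seed = request_seed if request_seed is not None else random.getrandbits(32)
        # max_tokens: read from the request's max_tokens field (before this commit
        # it was mistakenly read from the seed field)
        if request_max_tokens is not None:
            max_tokens = request_max_tokens
        else:
            max_tokens = default_max_tokens
        return seed, max_tokens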

cpp/json_ffi/openai_api_protocol.cc

Lines changed: 39 additions & 2 deletions
@@ -295,29 +295,66 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
   }
   request.model = model_res.Unwrap();
 
+  // temperature
+  Result<std::optional<double>> temperature_res =
+      json::LookupOptionalWithResultReturn<double>(json_obj, "temperature");
+  if (temperature_res.IsErr()) {
+    return TResult::Error(temperature_res.UnwrapErr());
+  }
+  request.temperature = temperature_res.Unwrap();
+  // top_p
+  Result<std::optional<double>> top_p_res =
+      json::LookupOptionalWithResultReturn<double>(json_obj, "top_p");
+  if (top_p_res.IsErr()) {
+    return TResult::Error(top_p_res.UnwrapErr());
+  }
+  request.top_p = top_p_res.Unwrap();
   // max_tokens
   Result<std::optional<int64_t>> max_tokens_res =
       json::LookupOptionalWithResultReturn<int64_t>(json_obj, "max_tokens");
   if (max_tokens_res.IsErr()) {
     return TResult::Error(max_tokens_res.UnwrapErr());
   }
   request.max_tokens = max_tokens_res.Unwrap();
-
   // frequency_penalty
   Result<std::optional<double>> frequency_penalty_res =
       json::LookupOptionalWithResultReturn<double>(json_obj, "frequency_penalty");
   if (frequency_penalty_res.IsErr()) {
     return TResult::Error(frequency_penalty_res.UnwrapErr());
   }
   request.frequency_penalty = frequency_penalty_res.Unwrap();
-
   // presence_penalty
   Result<std::optional<double>> presence_penalty_res =
       json::LookupOptionalWithResultReturn<double>(json_obj, "presence_penalty");
   if (presence_penalty_res.IsErr()) {
     return TResult::Error(presence_penalty_res.UnwrapErr());
   }
   request.presence_penalty = presence_penalty_res.Unwrap();
+  // seed
+  Result<std::optional<int64_t>> seed_res =
+      json::LookupOptionalWithResultReturn<int64_t>(json_obj, "seed");
+  if (seed_res.IsErr()) {
+    return TResult::Error(seed_res.UnwrapErr());
+  }
+  request.seed = seed_res.Unwrap();
+
+  // stop strings
+  Result<std::optional<picojson::array>> stop_strs_res =
+      json::LookupOptionalWithResultReturn<picojson::array>(json_obj, "stop");
+  if (stop_strs_res.IsErr()) {
+    return TResult::Error(stop_strs_res.UnwrapErr());
+  }
+  std::optional<picojson::array> stop_strs = stop_strs_res.Unwrap();
+  if (stop_strs.has_value()) {
+    std::vector<std::string> stop;
+    for (picojson::value stop_str_value : stop_strs.value()) {
+      if (!stop_str_value.is<std::string>()) {
+        return TResult::Error("One given value in field \"stop\" is not a string.");
+      }
+      stop.push_back(stop_str_value.get<std::string>());
+    }
+    request.stop = std::move(stop);
+  }
 
   // tool_choice
   Result<std::string> tool_choice_res =
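For reference, a sketch of a request body that `ChatCompletionRequest::FromJSON` can now parse, written here as a Python dict for illustration (model id, message, and stop strings are placeholder values):

    import json

    request_json_str = json.dumps({
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "temperature": 0.5,
        "top_p": 0.8,
        "frequency_penalty": 0.1,
        "presence_penalty": 0.1,
        "max_tokens": 100,
        "seed": 23,
        "stop": ["###", "<|end|>"],  # must be an array of strings, per the parser
    })

Per the new parsing code, every element of `stop` must be a string; anything else returns the error "One given value in field \"stop\" is not a string."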

python/mlc_llm/interface/chat.py

Lines changed: 55 additions & 8 deletions
@@ -1,27 +1,66 @@
 """Python entrypoint of chat."""
-from typing import List, Optional
+
+import dataclasses
+from typing import Dict, List, Optional, Union
 
 from prompt_toolkit import prompt as get_prompt  # pylint: disable=import-error
 from prompt_toolkit.key_binding import KeyBindings  # pylint: disable=import-error
 
 from mlc_llm.json_ffi import JSONFFIEngine
+from mlc_llm.support import argparse
+from mlc_llm.support.config import ConfigOverrideBase
+
+
+@dataclasses.dataclass
+class ChatCompletionOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes
+    """Flags for overriding chat completions."""
+
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    max_tokens: Optional[int] = None
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+
+    @staticmethod
+    def from_str(source: str) -> "ChatCompletionOverride":
+        """Parse model config override values from a string."""
+        parser = argparse.ArgumentParser(description="chat completion override values")
+        parser.add_argument("--temperature", type=float, default=None)
+        parser.add_argument("--top_p", type=float, default=None)
+        parser.add_argument("--frequency_penalty", type=float, default=None)
+        parser.add_argument("--presence_penalty", type=float, default=None)
+        parser.add_argument("--max_tokens", type=int, default=None)
+        parser.add_argument("--seed", type=int, default=None)
+        parser.add_argument("--stop", type=str, default=None)
+        results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
+        return ChatCompletionOverride(
+            temperature=results.temperature,
+            top_p=results.top_p,
+            frequency_penalty=results.frequency_penalty,
+            presence_penalty=results.presence_penalty,
+            max_tokens=results.max_tokens,
+            seed=results.seed,
+            stop=results.stop.split(",") if results.stop is not None else None,
+        )
 
 
 class ChatState:
     """Helper class to manage chat state"""
 
-    history: List[dict]
+    history: List[Dict]
     history_begin: int
     # kwargs passed to completions
-    overrides: dict
+    overrides: ChatCompletionOverride
     # we use JSON ffi engine to ensure broader coverage
     engine: JSONFFIEngine
 
     def __init__(self, engine):
         self.engine = engine
         self.history = []
         self.history_window_begin = 0
-        self.overrides = {}
+        self.overrides = ChatCompletionOverride()
 
     def process_system_prompts(self):
         """Process system prompts"""
@@ -45,7 +84,9 @@ def generate(self, prompt: str):
         finish_reason_length = False
         messages = self.history[self.history_window_begin :]
         for response in self.engine.chat.completions.create(
-            messages=messages, stream=True, **self.overrides
+            messages=messages,
+            stream=True,
+            **dataclasses.asdict(self.overrides),
         ):
             for choice in response.choices:
                 assert choice.delta.role == "assistant"
@@ -90,6 +131,9 @@ def _print_help_str():
     /stats              print out stats of last request (token/sec)
     /metrics            print out full engine metrics
     /reset              restart a fresh chat
+    /set [overrides]    override settings in the generation config. For example,
+                        `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
+                        Note: Separate stop words in the `stop` option with commas (,).
     Multi-line input: Use escape+enter to start a new line.
 """
     print(help_str)
@@ -132,16 +176,19 @@ def chat(
             key_bindings=kb,
             multiline=True,
         )
-        if prompt[:6] == "/stats":
+        if prompt[:4] == "/set":
+            overrides = ChatCompletionOverride.from_str(prompt.split()[1])
+            for key, value in dataclasses.asdict(overrides).items():
+                if value is not None:
+                    setattr(chat_state.overrides, key, value)
+        elif prompt[:6] == "/stats":
             print(chat_state.stats(), flush=True)
         elif prompt[:8] == "/metrics":
             print(chat_state.metrics(), flush=True)
         elif prompt[:6] == "/reset":
             chat_state.reset_chat()
         elif prompt[:5] == "/exit":
             break
-        # elif prompt[:6] == "/stats":
-        #     print(cm.stats(), flush=True)
         elif prompt[:5] == "/help":
             _print_help_str()
         else:
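A usage sketch of the new `/set` handling above: only the fields the user actually provides are merged into `chat_state.overrides`, since unspecified fields parse to None. For example (override values are arbitrary):

    overrides = ChatCompletionOverride.from_str("temperature=0.5;top_p=0.8;seed=23")
    assert overrides.temperature == 0.5
    assert overrides.max_tokens is None  # untouched fields stay None
    # chat() then copies only the non-None fields onto chat_state.overrides,
    # so repeated /set commands accumulate rather than reset settings.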
