1
1
"""Python entrypoint of chat."""
2
- from typing import List , Optional
2
+
3
+ import dataclasses
4
+ from typing import Dict , List , Optional , Union
3
5
4
6
from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error
5
7
from prompt_toolkit .key_binding import KeyBindings # pylint: disable=import-error
6
8
7
9
from mlc_llm .json_ffi import JSONFFIEngine
10
+ from mlc_llm .support import argparse
11
+ from mlc_llm .support .config import ConfigOverrideBase
12
+
13
+
14
+ @dataclasses .dataclass
15
+ class ChatCompletionOverride (ConfigOverrideBase ): # pylint: disable=too-many-instance-attributes
16
+ """Flags for overriding chat completions."""
17
+
18
+ temperature : Optional [float ] = None
19
+ top_p : Optional [float ] = None
20
+ frequency_penalty : Optional [float ] = None
21
+ presence_penalty : Optional [float ] = None
22
+ max_tokens : Optional [int ] = None
23
+ seed : Optional [int ] = None
24
+ stop : Optional [Union [str , List [str ]]] = None
25
+
26
+ @staticmethod
27
+ def from_str (source : str ) -> "ChatCompletionOverride" :
28
+ """Parse model config override values from a string."""
29
+ parser = argparse .ArgumentParser (description = "chat completion override values" )
30
+ parser .add_argument ("--temperature" , type = float , default = None )
31
+ parser .add_argument ("--top_p" , type = float , default = None )
32
+ parser .add_argument ("--frequency_penalty" , type = float , default = None )
33
+ parser .add_argument ("--presence_penalty" , type = float , default = None )
34
+ parser .add_argument ("--max_tokens" , type = int , default = None )
35
+ parser .add_argument ("--seed" , type = int , default = None )
36
+ parser .add_argument ("--stop" , type = str , default = None )
37
+ results = parser .parse_args ([f"--{ i } " for i in source .split (";" ) if i ])
38
+ return ChatCompletionOverride (
39
+ temperature = results .temperature ,
40
+ top_p = results .top_p ,
41
+ frequency_penalty = results .frequency_penalty ,
42
+ presence_penalty = results .presence_penalty ,
43
+ max_tokens = results .max_tokens ,
44
+ seed = results .seed ,
45
+ stop = results .stop .split ("," ) if results .stop is not None else None ,
46
+ )
8
47
9
48
10
49
class ChatState :
11
50
"""Helper class to manage chat state"""
12
51
13
- history : List [dict ]
52
+ history : List [Dict ]
14
53
history_begin : int
15
54
# kwargs passed to completions
16
- overrides : dict
55
+ overrides : ChatCompletionOverride
17
56
# we use JSON ffi engine to ensure broader coverage
18
57
engine : JSONFFIEngine
19
58
20
59
def __init__ (self , engine ):
21
60
self .engine = engine
22
61
self .history = []
23
62
self .history_window_begin = 0
24
- self .overrides = {}
63
+ self .overrides = ChatCompletionOverride ()
25
64
26
65
def process_system_prompts (self ):
27
66
"""Process system prompts"""
@@ -45,7 +84,9 @@ def generate(self, prompt: str):
45
84
finish_reason_length = False
46
85
messages = self .history [self .history_window_begin :]
47
86
for response in self .engine .chat .completions .create (
48
- messages = messages , stream = True , ** self .overrides
87
+ messages = messages ,
88
+ stream = True ,
89
+ ** dataclasses .asdict (self .overrides ),
49
90
):
50
91
for choice in response .choices :
51
92
assert choice .delta .role == "assistant"
@@ -90,6 +131,9 @@ def _print_help_str():
90
131
/stats print out stats of last request (token/sec)
91
132
/metrics print out full engine metrics
92
133
/reset restart a fresh chat
134
+ /set [overrides] override settings in the generation config. For example,
135
+ `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
136
+ Note: Separate stop words in the `stop` option with commas (,).
93
137
Multi-line input: Use escape+enter to start a new line.
94
138
"""
95
139
print (help_str )
@@ -132,16 +176,19 @@ def chat(
132
176
key_bindings = kb ,
133
177
multiline = True ,
134
178
)
135
- if prompt [:6 ] == "/stats" :
179
+ if prompt [:4 ] == "/set" :
180
+ overrides = ChatCompletionOverride .from_str (prompt .split ()[1 ])
181
+ for key , value in dataclasses .asdict (overrides ).items ():
182
+ if value is not None :
183
+ setattr (chat_state .overrides , key , value )
184
+ elif prompt [:6 ] == "/stats" :
136
185
print (chat_state .stats (), flush = True )
137
186
elif prompt [:8 ] == "/metrics" :
138
187
print (chat_state .metrics (), flush = True )
139
188
elif prompt [:6 ] == "/reset" :
140
189
chat_state .reset_chat ()
141
190
elif prompt [:5 ] == "/exit" :
142
191
break
143
- # elif prompt[:6] == "/stats":
144
- # print(cm.stats(), flush=True)
145
192
elif prompt [:5 ] == "/help" :
146
193
_print_help_str ()
147
194
else :
0 commit comments