@@ -63,9 +63,7 @@ def completer(text, state):
6363
6464
6565class ChatBot :
66- def __init__ (
67- self , api_key , model , system_message = None , temperature = DEFAULT_TEMPERATURE
68- ):
66+ def __init__ (self , api_key , model , system_message = None , temperature = DEFAULT_TEMPERATURE ):
6967 if not api_key :
7068 raise ValueError ("An API key must be provided to use the Mistral API." )
7169 self .client = MistralClient (api_key = api_key )
@@ -89,15 +87,11 @@ def opening_instructions(self):
8987
9088 def new_chat (self ):
9189 print ("" )
92- print (
93- f"Starting new chat with model: { self .model } , temperature: { self .temperature } "
94- )
90+ print (f"Starting new chat with model: { self .model } , temperature: { self .temperature } " )
9591 print ("" )
9692 self .messages = []
9793 if self .system_message :
98- self .messages .append (
99- ChatMessage (role = "system" , content = self .system_message )
100- )
94+ self .messages .append (ChatMessage (role = "system" , content = self .system_message ))
10195
10296 def switch_model (self , input ):
10397 model = self .get_arguments (input )
@@ -146,13 +140,9 @@ def run_inference(self, content):
146140 self .messages .append (ChatMessage (role = "user" , content = content ))
147141
148142 assistant_response = ""
149- logger .debug (
150- f"Running inference with model: { self .model } , temperature: { self .temperature } "
151- )
143+ logger .debug (f"Running inference with model: { self .model } , temperature: { self .temperature } " )
152144 logger .debug (f"Sending messages: { self .messages } " )
153- for chunk in self .client .chat_stream (
154- model = self .model , temperature = self .temperature , messages = self .messages
155- ):
145+ for chunk in self .client .chat_stream (model = self .model , temperature = self .temperature , messages = self .messages ):
156146 response = chunk .choices [0 ].delta .content
157147 if response is not None :
158148 print (response , end = "" , flush = True )
@@ -161,9 +151,7 @@ def run_inference(self, content):
161151 print ("" , flush = True )
162152
163153 if assistant_response :
164- self .messages .append (
165- ChatMessage (role = "assistant" , content = assistant_response )
166- )
154+ self .messages .append (ChatMessage (role = "assistant" , content = assistant_response ))
167155 logger .debug (f"Current messages: { self .messages } " )
168156
169157 def get_command (self , input ):
@@ -215,9 +203,7 @@ def exit(self):
215203
216204
217205if __name__ == "__main__" :
218- parser = argparse .ArgumentParser (
219- description = "A simple chatbot using the Mistral API"
220- )
206+ parser = argparse .ArgumentParser (description = "A simple chatbot using the Mistral API" )
221207 parser .add_argument (
222208 "--api-key" ,
223209 default = os .environ .get ("MISTRAL_API_KEY" ),
@@ -230,19 +216,15 @@ def exit(self):
230216 default = DEFAULT_MODEL ,
231217 help = "Model for chat inference. Choices are %(choices)s. Defaults to %(default)s" ,
232218 )
233- parser .add_argument (
234- "-s" , "--system-message" , help = "Optional system message to prepend."
235- )
219+ parser .add_argument ("-s" , "--system-message" , help = "Optional system message to prepend." )
236220 parser .add_argument (
237221 "-t" ,
238222 "--temperature" ,
239223 type = float ,
240224 default = DEFAULT_TEMPERATURE ,
241225 help = "Optional temperature for chat inference. Defaults to %(default)s" ,
242226 )
243- parser .add_argument (
244- "-d" , "--debug" , action = "store_true" , help = "Enable debug logging"
245- )
227+ parser .add_argument ("-d" , "--debug" , action = "store_true" , help = "Enable debug logging" )
246228
247229 args = parser .parse_args ()
248230
0 commit comments