11import chainlit as cl
22from chainlit .input_widget import TextInput
3- from chainlit .types import ThreadDict
3+ from chainlit .types import ThreadDict # Change this import
44from litellm import acompletion
55import os
66import sqlite3
1414import logging
1515import json
1616from sql_alchemy import SQLAlchemyDataLayer
17+ from tavily import TavilyClient
1718
1819# Set up logging
1920logger = logging .getLogger (__name__ )
@@ -171,6 +172,41 @@ def load_setting(key: str) -> str:
171172
172173cl_data ._data_layer = SQLAlchemyDataLayer (conninfo = f"sqlite+aiosqlite:///{ DB_PATH } " )
173174
175+ # Set Tavily API key
176+ tavily_api_key = os .getenv ("TAVILY_API_KEY" )
177+ tavily_client = TavilyClient (api_key = tavily_api_key ) if tavily_api_key else None
178+
179+ # Function to call Tavily Search API
180+ def tavily_web_search (query ):
181+ if not tavily_client :
182+ return json .dumps ({
183+ "query" : query ,
184+ "error" : "Tavily API key is not set. Web search is unavailable."
185+ })
186+ response = tavily_client .search (query )
187+ print (response ) # Print the full response
188+ return json .dumps ({
189+ "query" : query ,
190+ "answer" : response .get ('answer' ),
191+ "top_result" : response ['results' ][0 ]['content' ] if response ['results' ] else 'No results found'
192+ })
193+
194+ # Define the tool for function calling
195+ tools = [{
196+ "type" : "function" ,
197+ "function" : {
198+ "name" : "tavily_web_search" ,
199+ "description" : "Search the web using Tavily API" ,
200+ "parameters" : {
201+ "type" : "object" ,
202+ "properties" : {
203+ "query" : {"type" : "string" , "description" : "Search query" }
204+ },
205+ "required" : ["query" ]
206+ }
207+ }
208+ }] if tavily_api_key else []
209+
174210@cl .on_chat_start
175211async def start ():
176212 initialize_db ()
@@ -224,31 +260,130 @@ async def setup_agent(settings):
224260async def main (message : cl .Message ):
225261 model_name = load_setting ("model_name" ) or os .getenv ("MODEL_NAME" ) or "gpt-4o-mini"
226262 message_history = cl .user_session .get ("message_history" , [])
227- message_history .append ({"role" : "user" , "content" : message .content })
263+ now = datetime .now ().strftime ("%Y-%m-%d %H:%M:%S" )
264+
265+ # Add the current date and time to the user's message
266+ user_message = f"""
267+ Answer the question and use tools if needed:\n { message .content } .\n \n
268+ Current Date and Time: { now }
269+ """
270+ message_history .append ({"role" : "user" , "content" : user_message })
228271
229272 msg = cl .Message (content = "" )
230273 await msg .send ()
231274
232- response = await acompletion (
233- model = model_name ,
234- messages = message_history ,
235- stream = True ,
236- # temperature=0.7,
237- # max_tokens=500,
238- # top_p=1
239- )
275+ # Prepare the completion parameters
276+ completion_params = {
277+ "model" : model_name ,
278+ "messages" : message_history ,
279+ "stream" : True ,
280+ }
281+
282+ # Only add tools and tool_choice if Tavily API key is available
283+ if tavily_api_key :
284+ completion_params ["tools" ] = tools
285+ completion_params ["tool_choice" ] = "auto"
286+
287+ response = await acompletion (** completion_params )
240288
241289 full_response = ""
290+ tool_calls = []
291+ current_tool_call = None
292+
242293 async for part in response :
243- if token := part ['choices' ][0 ]['delta' ]['content' ]:
244- await msg .stream_token (token )
245- full_response += token
294+ if 'choices' in part and len (part ['choices' ]) > 0 :
295+ delta = part ['choices' ][0 ].get ('delta' , {})
296+
297+ if 'content' in delta and delta ['content' ] is not None :
298+ token = delta ['content' ]
299+ await msg .stream_token (token )
300+ full_response += token
301+
302+ if tavily_api_key and 'tool_calls' in delta and delta ['tool_calls' ] is not None :
303+ for tool_call in delta ['tool_calls' ]:
304+ if current_tool_call is None or tool_call .index != current_tool_call ['index' ]:
305+ if current_tool_call :
306+ tool_calls .append (current_tool_call )
307+ current_tool_call = {
308+ 'id' : tool_call .id ,
309+ 'type' : tool_call .type ,
310+ 'index' : tool_call .index ,
311+ 'function' : {
312+ 'name' : tool_call .function .name if tool_call .function else None ,
313+ 'arguments' : ''
314+ }
315+ }
316+ if tool_call .function :
317+ if tool_call .function .name :
318+ current_tool_call ['function' ]['name' ] = tool_call .function .name
319+ if tool_call .function .arguments :
320+ current_tool_call ['function' ]['arguments' ] += tool_call .function .arguments
321+
322+ if current_tool_call :
323+ tool_calls .append (current_tool_call )
324+
246325 logger .debug (f"Full response: { full_response } " )
326+ logger .debug (f"Tool calls: { tool_calls } " )
247327 message_history .append ({"role" : "assistant" , "content" : full_response })
248328 logger .debug (f"Message history: { message_history } " )
249329 cl .user_session .set ("message_history" , message_history )
250330 await msg .update ()
251331
332+ if tavily_api_key and tool_calls :
333+ available_functions = {
334+ "tavily_web_search" : tavily_web_search ,
335+ }
336+ messages = message_history + [{"role" : "assistant" , "content" : None , "function_call" : {
337+ "name" : tool_calls [0 ]['function' ]['name' ],
338+ "arguments" : tool_calls [0 ]['function' ]['arguments' ]
339+ }}]
340+
341+ for tool_call in tool_calls :
342+ function_name = tool_call ['function' ]['name' ]
343+ if function_name in available_functions :
344+ function_to_call = available_functions [function_name ]
345+ function_args = tool_call ['function' ]['arguments' ]
346+ if function_args :
347+ try :
348+ function_args = json .loads (function_args )
349+ function_response = function_to_call (
350+ query = function_args .get ("query" ),
351+ )
352+ messages .append (
353+ {
354+ "role" : "function" ,
355+ "name" : function_name ,
356+ "content" : function_response ,
357+ }
358+ )
359+ except json .JSONDecodeError :
360+ logger .error (f"Failed to parse function arguments: { function_args } " )
361+
362+ second_response = await acompletion (
363+ model = model_name ,
364+ stream = True ,
365+ messages = messages ,
366+ )
367+ logger .debug (f"Second LLM response: { second_response } " )
368+
369+ # Handle the streaming response
370+ full_response = ""
371+ async for part in second_response :
372+ if 'choices' in part and len (part ['choices' ]) > 0 :
373+ delta = part ['choices' ][0 ].get ('delta' , {})
374+ if 'content' in delta and delta ['content' ] is not None :
375+ token = delta ['content' ]
376+ await msg .stream_token (token )
377+ full_response += token
378+
379+ # Update the message content
380+ msg .content = full_response
381+ await msg .update ()
382+ else :
383+ # If no tool calls or Tavily API key is not set, the full_response is already set
384+ msg .content = full_response
385+ await msg .update ()
386+
252387username = os .getenv ("CHAINLIT_USERNAME" , "admin" ) # Default to "admin" if not found
253388password = os .getenv ("CHAINLIT_PASSWORD" , "admin" ) # Default to "admin" if not found
254389
@@ -267,7 +402,7 @@ async def send_count():
267402 ).send ()
268403
269404@cl .on_chat_resume
270- async def on_chat_resume (thread : cl_data . ThreadDict ):
405+ async def on_chat_resume (thread : ThreadDict ): # Change the type hint here
271406 logger .info (f"Resuming chat: { thread ['id' ]} " )
272407 model_name = load_setting ("model_name" ) or os .getenv ("MODEL_NAME" ) or "gpt-4o-mini"
273408 logger .debug (f"Model name: { model_name } " )
@@ -285,8 +420,14 @@ async def on_chat_resume(thread: cl_data.ThreadDict):
285420 thread_id = thread ["id" ]
286421 cl .user_session .set ("thread_id" , thread ["id" ])
287422
288- # The metadata should now already be a dictionary
423+ # Ensure metadata is a dictionary
289424 metadata = thread .get ("metadata" , {})
425+ if isinstance (metadata , str ):
426+ try :
427+ metadata = json .loads (metadata )
428+ except json .JSONDecodeError :
429+ metadata = {}
430+
290431 cl .user_session .set ("metadata" , metadata )
291432
292433 message_history = cl .user_session .get ("message_history" , [])
@@ -298,7 +439,14 @@ async def on_chat_resume(thread: cl_data.ThreadDict):
298439 message_history .append ({"role" : "user" , "content" : message .get ("output" , "" )})
299440 elif msg_type == "assistant_message" :
300441 message_history .append ({"role" : "assistant" , "content" : message .get ("output" , "" )})
442+ elif msg_type == "run" :
443+ # Handle 'run' type messages
444+ if message .get ("isError" ):
445+ message_history .append ({"role" : "system" , "content" : f"Error: { message .get ('output' , '' )} " })
446+ else :
447+ # You might want to handle non-error 'run' messages differently
448+ pass
301449 else :
302- logger .warning (f"Message without type: { message } " )
450+ logger .warning (f"Message without recognized type: { message } " )
303451
304452 cl .user_session .set ("message_history" , message_history )
0 commit comments