Skip to content

Commit c4cbe88

Browse files
Merge pull request #158 from MervinPraison/develop
Adding Search with Tavily Feature to Chat and Code
2 parents 285e6c5 + 720921f commit c4cbe88

File tree

8 files changed

+441
-144
lines changed

8 files changed

+441
-144
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
FROM python:3.11-slim
22
WORKDIR /app
33
COPY . .
4-
RUN pip install flask praisonai==0.0.70 gunicorn markdown
4+
RUN pip install flask praisonai==0.0.71 gunicorn markdown
55
EXPOSE 8080
66
CMD ["gunicorn", "-b", "0.0.0.0:8080", "api:app"]

docs/api/praisonai/deploy.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ <h2 id="raises">Raises</h2>
110110
file.write(&#34;FROM python:3.11-slim\n&#34;)
111111
file.write(&#34;WORKDIR /app\n&#34;)
112112
file.write(&#34;COPY . .\n&#34;)
113-
file.write(&#34;RUN pip install flask praisonai==0.0.70 gunicorn markdown\n&#34;)
113+
file.write(&#34;RUN pip install flask praisonai==0.0.71 gunicorn markdown\n&#34;)
114114
file.write(&#34;EXPOSE 8080\n&#34;)
115115
file.write(&#39;CMD [&#34;gunicorn&#34;, &#34;-b&#34;, &#34;0.0.0.0:8080&#34;, &#34;api:app&#34;]\n&#39;)
116116

poetry.lock

Lines changed: 119 additions & 103 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

praisonai.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ class Praisonai < Formula
33

44
desc "AI tools for various AI applications"
55
homepage "https://github.com/MervinPraison/PraisonAI"
6-
url "https://github.com/MervinPraison/PraisonAI/archive/refs/tags/0.0.70.tar.gz"
6+
url "https://github.com/MervinPraison/PraisonAI/archive/refs/tags/0.0.71.tar.gz"
77
sha256 "1828fb9227d10f991522c3f24f061943a254b667196b40b1a3e4a54a8d30ce32" # Replace with actual SHA256 checksum
88
license "MIT"
99

praisonai/deploy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def create_dockerfile(self):
5656
file.write("FROM python:3.11-slim\n")
5757
file.write("WORKDIR /app\n")
5858
file.write("COPY . .\n")
59-
file.write("RUN pip install flask praisonai==0.0.70 gunicorn markdown\n")
59+
file.write("RUN pip install flask praisonai==0.0.71 gunicorn markdown\n")
6060
file.write("EXPOSE 8080\n")
6161
file.write('CMD ["gunicorn", "-b", "0.0.0.0:8080", "api:app"]\n')
6262

praisonai/ui/chat.py

Lines changed: 164 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import chainlit as cl
22
from chainlit.input_widget import TextInput
3-
from chainlit.types import ThreadDict
3+
from chainlit.types import ThreadDict # Change this import
44
from litellm import acompletion
55
import os
66
import sqlite3
@@ -14,6 +14,7 @@
1414
import logging
1515
import json
1616
from sql_alchemy import SQLAlchemyDataLayer
17+
from tavily import TavilyClient
1718

1819
# Set up logging
1920
logger = logging.getLogger(__name__)
@@ -171,6 +172,41 @@ def load_setting(key: str) -> str:
171172

172173
cl_data._data_layer = SQLAlchemyDataLayer(conninfo=f"sqlite+aiosqlite:///{DB_PATH}")
173174

175+
# Set Tavily API key
176+
tavily_api_key = os.getenv("TAVILY_API_KEY")
177+
tavily_client = TavilyClient(api_key=tavily_api_key) if tavily_api_key else None
178+
179+
# Function to call Tavily Search API
180+
def tavily_web_search(query):
181+
if not tavily_client:
182+
return json.dumps({
183+
"query": query,
184+
"error": "Tavily API key is not set. Web search is unavailable."
185+
})
186+
response = tavily_client.search(query)
187+
print(response) # Print the full response
188+
return json.dumps({
189+
"query": query,
190+
"answer": response.get('answer'),
191+
"top_result": response['results'][0]['content'] if response['results'] else 'No results found'
192+
})
193+
194+
# Define the tool for function calling
195+
tools = [{
196+
"type": "function",
197+
"function": {
198+
"name": "tavily_web_search",
199+
"description": "Search the web using Tavily API",
200+
"parameters": {
201+
"type": "object",
202+
"properties": {
203+
"query": {"type": "string", "description": "Search query"}
204+
},
205+
"required": ["query"]
206+
}
207+
}
208+
}] if tavily_api_key else []
209+
174210
@cl.on_chat_start
175211
async def start():
176212
initialize_db()
@@ -224,31 +260,130 @@ async def setup_agent(settings):
224260
async def main(message: cl.Message):
225261
model_name = load_setting("model_name") or os.getenv("MODEL_NAME") or "gpt-4o-mini"
226262
message_history = cl.user_session.get("message_history", [])
227-
message_history.append({"role": "user", "content": message.content})
263+
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
264+
265+
# Add the current date and time to the user's message
266+
user_message = f"""
267+
Answer the question and use tools if needed:\n{message.content}.\n\n
268+
Current Date and Time: {now}
269+
"""
270+
message_history.append({"role": "user", "content": user_message})
228271

229272
msg = cl.Message(content="")
230273
await msg.send()
231274

232-
response = await acompletion(
233-
model=model_name,
234-
messages=message_history,
235-
stream=True,
236-
# temperature=0.7,
237-
# max_tokens=500,
238-
# top_p=1
239-
)
275+
# Prepare the completion parameters
276+
completion_params = {
277+
"model": model_name,
278+
"messages": message_history,
279+
"stream": True,
280+
}
281+
282+
# Only add tools and tool_choice if Tavily API key is available
283+
if tavily_api_key:
284+
completion_params["tools"] = tools
285+
completion_params["tool_choice"] = "auto"
286+
287+
response = await acompletion(**completion_params)
240288

241289
full_response = ""
290+
tool_calls = []
291+
current_tool_call = None
292+
242293
async for part in response:
243-
if token := part['choices'][0]['delta']['content']:
244-
await msg.stream_token(token)
245-
full_response += token
294+
if 'choices' in part and len(part['choices']) > 0:
295+
delta = part['choices'][0].get('delta', {})
296+
297+
if 'content' in delta and delta['content'] is not None:
298+
token = delta['content']
299+
await msg.stream_token(token)
300+
full_response += token
301+
302+
if tavily_api_key and 'tool_calls' in delta and delta['tool_calls'] is not None:
303+
for tool_call in delta['tool_calls']:
304+
if current_tool_call is None or tool_call.index != current_tool_call['index']:
305+
if current_tool_call:
306+
tool_calls.append(current_tool_call)
307+
current_tool_call = {
308+
'id': tool_call.id,
309+
'type': tool_call.type,
310+
'index': tool_call.index,
311+
'function': {
312+
'name': tool_call.function.name if tool_call.function else None,
313+
'arguments': ''
314+
}
315+
}
316+
if tool_call.function:
317+
if tool_call.function.name:
318+
current_tool_call['function']['name'] = tool_call.function.name
319+
if tool_call.function.arguments:
320+
current_tool_call['function']['arguments'] += tool_call.function.arguments
321+
322+
if current_tool_call:
323+
tool_calls.append(current_tool_call)
324+
246325
logger.debug(f"Full response: {full_response}")
326+
logger.debug(f"Tool calls: {tool_calls}")
247327
message_history.append({"role": "assistant", "content": full_response})
248328
logger.debug(f"Message history: {message_history}")
249329
cl.user_session.set("message_history", message_history)
250330
await msg.update()
251331

332+
if tavily_api_key and tool_calls:
333+
available_functions = {
334+
"tavily_web_search": tavily_web_search,
335+
}
336+
messages = message_history + [{"role": "assistant", "content": None, "function_call": {
337+
"name": tool_calls[0]['function']['name'],
338+
"arguments": tool_calls[0]['function']['arguments']
339+
}}]
340+
341+
for tool_call in tool_calls:
342+
function_name = tool_call['function']['name']
343+
if function_name in available_functions:
344+
function_to_call = available_functions[function_name]
345+
function_args = tool_call['function']['arguments']
346+
if function_args:
347+
try:
348+
function_args = json.loads(function_args)
349+
function_response = function_to_call(
350+
query=function_args.get("query"),
351+
)
352+
messages.append(
353+
{
354+
"role": "function",
355+
"name": function_name,
356+
"content": function_response,
357+
}
358+
)
359+
except json.JSONDecodeError:
360+
logger.error(f"Failed to parse function arguments: {function_args}")
361+
362+
second_response = await acompletion(
363+
model=model_name,
364+
stream=True,
365+
messages=messages,
366+
)
367+
logger.debug(f"Second LLM response: {second_response}")
368+
369+
# Handle the streaming response
370+
full_response = ""
371+
async for part in second_response:
372+
if 'choices' in part and len(part['choices']) > 0:
373+
delta = part['choices'][0].get('delta', {})
374+
if 'content' in delta and delta['content'] is not None:
375+
token = delta['content']
376+
await msg.stream_token(token)
377+
full_response += token
378+
379+
# Update the message content
380+
msg.content = full_response
381+
await msg.update()
382+
else:
383+
# If no tool calls or Tavily API key is not set, the full_response is already set
384+
msg.content = full_response
385+
await msg.update()
386+
252387
username = os.getenv("CHAINLIT_USERNAME", "admin") # Default to "admin" if not found
253388
password = os.getenv("CHAINLIT_PASSWORD", "admin") # Default to "admin" if not found
254389

@@ -267,7 +402,7 @@ async def send_count():
267402
).send()
268403

269404
@cl.on_chat_resume
270-
async def on_chat_resume(thread: cl_data.ThreadDict):
405+
async def on_chat_resume(thread: ThreadDict): # Change the type hint here
271406
logger.info(f"Resuming chat: {thread['id']}")
272407
model_name = load_setting("model_name") or os.getenv("MODEL_NAME") or "gpt-4o-mini"
273408
logger.debug(f"Model name: {model_name}")
@@ -285,8 +420,14 @@ async def on_chat_resume(thread: cl_data.ThreadDict):
285420
thread_id = thread["id"]
286421
cl.user_session.set("thread_id", thread["id"])
287422

288-
# The metadata should now already be a dictionary
423+
# Ensure metadata is a dictionary
289424
metadata = thread.get("metadata", {})
425+
if isinstance(metadata, str):
426+
try:
427+
metadata = json.loads(metadata)
428+
except json.JSONDecodeError:
429+
metadata = {}
430+
290431
cl.user_session.set("metadata", metadata)
291432

292433
message_history = cl.user_session.get("message_history", [])
@@ -298,7 +439,14 @@ async def on_chat_resume(thread: cl_data.ThreadDict):
298439
message_history.append({"role": "user", "content": message.get("output", "")})
299440
elif msg_type == "assistant_message":
300441
message_history.append({"role": "assistant", "content": message.get("output", "")})
442+
elif msg_type == "run":
443+
# Handle 'run' type messages
444+
if message.get("isError"):
445+
message_history.append({"role": "system", "content": f"Error: {message.get('output', '')}"})
446+
else:
447+
# You might want to handle non-error 'run' messages differently
448+
pass
301449
else:
302-
logger.warning(f"Message without type: {message}")
450+
logger.warning(f"Message without recognized type: {message}")
303451

304452
cl.user_session.set("message_history", message_history)

0 commit comments

Comments
 (0)