Skip to content

Commit 65e22c2

Browse files
fix: enable custom tools in chat interface
- Add custom tool loading from tools.py - Extend tools list to include both Tavily and custom tools - Update tool handling to work with any available tools - Fix function calling to handle different parameter signatures - Add proper error handling for sync/async functions - Fix AsyncWebCrawler typo Fixes #620 Co-authored-by: Mervin Praison <[email protected]>
1 parent 53c0b84 commit 65e22c2

File tree

1 file changed

+136
-23
lines changed

1 file changed

+136
-23
lines changed

src/praisonai/praisonai/ui/chat.py

Lines changed: 136 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import asyncio
88
import io
99
import base64
10+
import importlib.util
11+
import inspect
1012

1113
# Third-party imports
1214
from dotenv import load_dotenv
@@ -59,6 +61,76 @@ def load_setting(key: str) -> str:
5961

6062
cl_data._data_layer = db_manager
6163

64+
def load_custom_tools():
65+
"""Load custom tools from tools.py if it exists"""
66+
custom_tools = {}
67+
try:
68+
spec = importlib.util.spec_from_file_location("tools", "tools.py")
69+
if spec is None:
70+
logger.debug("tools.py not found in current directory")
71+
return custom_tools
72+
73+
module = importlib.util.module_from_spec(spec)
74+
spec.loader.exec_module(module)
75+
76+
# Load all functions from tools.py
77+
for name, obj in inspect.getmembers(module):
78+
if not name.startswith('_') and callable(obj) and not inspect.isclass(obj):
79+
# Store function in globals for access
80+
globals()[name] = obj
81+
82+
# Get function signature to build parameters
83+
sig = inspect.signature(obj)
84+
params_properties = {}
85+
required_params = []
86+
87+
for param_name, param in sig.parameters.items():
88+
if param_name != 'self': # Skip self parameter
89+
# Get type annotation if available
90+
param_type = "string" # Default type
91+
if param.annotation != inspect.Parameter.empty:
92+
if param.annotation == int:
93+
param_type = "integer"
94+
elif param.annotation == float:
95+
param_type = "number"
96+
elif param.annotation == bool:
97+
param_type = "boolean"
98+
99+
params_properties[param_name] = {
100+
"type": param_type,
101+
"description": f"Parameter {param_name}"
102+
}
103+
104+
# Add to required if no default value
105+
if param.default == inspect.Parameter.empty:
106+
required_params.append(param_name)
107+
108+
# Build tool definition
109+
tool_def = {
110+
"type": "function",
111+
"function": {
112+
"name": name,
113+
"description": obj.__doc__ or f"Function {name.replace('_', ' ')}",
114+
"parameters": {
115+
"type": "object",
116+
"properties": params_properties,
117+
"required": required_params
118+
}
119+
}
120+
}
121+
122+
custom_tools[name] = tool_def
123+
logger.info(f"Loaded custom tool: {name}")
124+
125+
logger.info(f"Loaded {len(custom_tools)} custom tools from tools.py")
126+
except Exception as e:
127+
logger.warning(f"Error loading custom tools: {e}")
128+
129+
return custom_tools
130+
131+
# Load custom tools
132+
custom_tools_dict = load_custom_tools()
133+
62134
tavily_api_key = os.getenv("TAVILY_API_KEY")
63135
tavily_client = TavilyClient(api_key=tavily_api_key) if tavily_api_key else None
64136

@@ -72,7 +144,7 @@ async def tavily_web_search(query):
72144
response = tavily_client.search(query)
73145
logger.debug(f"Tavily search response: {response}")
74146

75-
async with AsyncAsyncWebCrawler() as crawler:
147+
async with AsyncWebCrawler() as crawler:
76148
results = []
77149
for result in response.get('results', []):
78150
url = result.get('url')
@@ -97,20 +169,28 @@ async def tavily_web_search(query):
97169
"results": results
98170
})
99171

100-
tools = [{
101-
"type": "function",
102-
"function": {
103-
"name": "tavily_web_search",
104-
"description": "Search the web using Tavily API and crawl the resulting URLs",
105-
"parameters": {
106-
"type": "object",
107-
"properties": {
108-
"query": {"type": "string", "description": "Search query"}
109-
},
110-
"required": ["query"]
172+
# Build tools list with Tavily and custom tools
173+
tools = []
174+
175+
# Add Tavily tool if API key is available
176+
if tavily_api_key:
177+
tools.append({
178+
"type": "function",
179+
"function": {
180+
"name": "tavily_web_search",
181+
"description": "Search the web using Tavily API and crawl the resulting URLs",
182+
"parameters": {
183+
"type": "object",
184+
"properties": {
185+
"query": {"type": "string", "description": "Search query"}
186+
},
187+
"required": ["query"]
188+
}
111189
}
112-
}
113-
}] if tavily_api_key else []
190+
})
191+
192+
# Add custom tools from tools.py
193+
tools.extend(list(custom_tools_dict.values()))
114194

115195
# Authentication configuration
116196
AUTH_PASSWORD_ENABLED = os.getenv("AUTH_PASSWORD_ENABLED", "true").lower() == "true" # Password authentication enabled by default
@@ -235,7 +315,8 @@ async def main(message: cl.Message):
235315
]
236316
}
237317

238-
if tavily_api_key:
318+
# Pass tools if we have any (Tavily or custom)
319+
if tools:
239320
completion_params["tools"] = tools
240321
completion_params["tool_choice"] = "auto"
241322

@@ -254,7 +335,7 @@ async def main(message: cl.Message):
254335
await msg.stream_token(token)
255336
full_response += token
256337

257-
if tavily_api_key and 'tool_calls' in delta and delta['tool_calls'] is not None:
338+
if tools and 'tool_calls' in delta and delta['tool_calls'] is not None:
258339
for tool_call in delta['tool_calls']:
259340
if current_tool_call is None or tool_call.index != current_tool_call['index']:
260341
if current_tool_call:
@@ -284,10 +365,17 @@ async def main(message: cl.Message):
284365
cl.user_session.set("message_history", message_history)
285366
await msg.update()
286367

287-
if tavily_api_key and tool_calls:
288-
available_functions = {
289-
"tavily_web_search": tavily_web_search,
290-
}
368+
if tool_calls and tools: # Check if we have any tools and tool calls
369+
available_functions = {}
370+
371+
# Add Tavily function if available
372+
if tavily_api_key:
373+
available_functions["tavily_web_search"] = tavily_web_search
374+
375+
# Add all custom tool functions from globals
376+
for tool_name in custom_tools_dict:
377+
if tool_name in globals():
378+
available_functions[tool_name] = globals()[tool_name]
291379
messages = message_history + [{"role": "assistant", "content": None, "function_call": {
292380
"name": tool_calls[0]['function']['name'],
293381
"arguments": tool_calls[0]['function']['arguments']
@@ -301,9 +389,25 @@ async def main(message: cl.Message):
301389
if function_args:
302390
try:
303391
function_args = json.loads(function_args)
304-
function_response = await function_to_call(
305-
query=function_args.get("query"),
306-
)
392+
393+
# Call function based on whether it's async or sync
394+
if asyncio.iscoroutinefunction(function_to_call):
395+
# For async functions like tavily_web_search
396+
if function_name == "tavily_web_search":
397+
function_response = await function_to_call(
398+
query=function_args.get("query"),
399+
)
400+
else:
401+
# For custom async functions, pass all arguments
402+
function_response = await function_to_call(**function_args)
403+
else:
404+
# For sync functions (most custom tools)
405+
function_response = function_to_call(**function_args)
406+
407+
# Convert response to string if needed
408+
if not isinstance(function_response, str):
409+
function_response = json.dumps(function_response)
410+
307411
messages.append(
308412
{
309413
"role": "function",
@@ -313,6 +417,15 @@ async def main(message: cl.Message):
313417
)
314418
except json.JSONDecodeError:
315419
logger.error(f"Failed to parse function arguments: {function_args}")
420+
except Exception as e:
421+
logger.error(f"Error calling function {function_name}: {str(e)}")
422+
messages.append(
423+
{
424+
"role": "function",
425+
"name": function_name,
426+
"content": f"Error: {str(e)}",
427+
}
428+
)
316429

317430
second_response = await acompletion(
318431
model=model_name,

0 commit comments

Comments
 (0)