1- from django .db import connection
21from typing import Dict , Any , Callable , List
32from dataclasses import dataclass
3+
4+ from django .db import connection
5+
46from .database import ask_database , get_database_info
57
68database_schema_dict = get_database_info (connection )
1113 ]
1214)
1315
16+
1417@dataclass
1518class ToolFunction :
1619 name : str
1720 func : Callable
1821 description : str
1922 parameters : Dict [str , Any ]
2023
24+
2125def create_tool_dict (tool : ToolFunction ) -> Dict [str , Any ]:
2226 return {
2327 "type" : "function" ,
@@ -28,10 +32,11 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
2832 "type" : "object" ,
2933 "properties" : tool .parameters ,
3034 "required" : list (tool .parameters .keys ()),
31- }
32- }
35+ },
36+ },
3337 }
3438
39+
3540TOOL_FUNCTIONS = [
3641 ToolFunction (
3742 name = "ask_database" ,
@@ -56,60 +61,58 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
5661 SQL should be written using this database schema:
5762 { database_schema_string }
5863 The query should be returned in plain text, not in JSON.
59- """
64+ """ ,
6065 }
61- }
66+ },
6267 ),
6368]
6469
6570# Automatically generate the tool_functions dictionary and tools list
66- tool_functions : Dict [str , Callable ] = {
67- tool . name : tool . func for tool in TOOL_FUNCTIONS
68- }
71+ tool_functions : Dict [str , Callable ] = {tool . name : tool . func for tool in TOOL_FUNCTIONS }
72+
73+ tools : List [ Dict [ str , Any ]] = [ create_tool_dict ( tool ) for tool in TOOL_FUNCTIONS ]
6974
70- tools : List [Dict [str , Any ]] = [
71- create_tool_dict (tool ) for tool in TOOL_FUNCTIONS
72- ]
7375
7476def validate_tool_inputs (tool_function_name , tool_arguments ):
7577 """Validate the inputs for the execute_tool function."""
7678 if not isinstance (tool_function_name , str ) or not tool_function_name :
7779 raise ValueError ("Invalid tool function name" )
78-
80+
7981 if not isinstance (tool_arguments , dict ):
8082 raise ValueError ("Tool arguments must be a dictionary" )
81-
83+
8284 # Check if the tool_function_name exists in the tools
8385 tool = next ((t for t in tools if t ["function" ]["name" ] == tool_function_name ), None )
8486 if not tool :
8587 raise ValueError (f"Tool function '{ tool_function_name } ' does not exist" )
86-
88+
8789 # Validate the tool arguments based on the tool's parameters
8890 parameters = tool ["function" ].get ("parameters" , {})
8991 required_params = parameters .get ("required" , [])
9092 for param in required_params :
9193 if param not in tool_arguments :
9294 raise ValueError (f"Missing required parameter: { param } " )
93-
95+
9496 # Check if the parameter types match the expected types
9597 properties = parameters .get ("properties" , {})
9698 for param , prop in properties .items ():
97- expected_type = prop .get (' type' )
99+ expected_type = prop .get (" type" )
98100 if param in tool_arguments :
99- if expected_type == ' string' and not isinstance (tool_arguments [param ], str ):
101+ if expected_type == " string" and not isinstance (tool_arguments [param ], str ):
100102 raise ValueError (f"Parameter '{ param } ' must be of type string" )
101-
103+
104+
102105def execute_tool (function_name : str , arguments : Dict [str , Any ]) -> str :
103106 """
104107 Execute the appropriate function based on the function name.
105-
108+
106109 :param function_name: The name of the function to execute
107110 :param arguments: A dictionary of arguments to pass to the function
108111 :return: The result of the function execution
109112 """
110113 # Validate tool inputs
111114 validate_tool_inputs (function_name , arguments )
112-
115+
113116 try :
114117 return tool_functions [function_name ](** arguments )
115118 except Exception as e :
0 commit comments