Skip to content

Commit 3ad3a5e

Browse files
committed
Update agent logic
1 parent 3ea1ffe commit 3ad3a5e

File tree

9 files changed

+89
-17
lines changed

9 files changed

+89
-17
lines changed

text_2_sql/autogen/agentic_text_2_sql.ipynb renamed to text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

File renamed without changes.

text_2_sql/autogen/agents/llm_agents/answer_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model:
2-
gpt-4o-mini
2+
4o-mini
33
description:
44
"An agent that takes the final results from the SQL query and writes the answer to the user's question"
55
system_message:

text_2_sql/autogen/agents/llm_agents/question_decomposition_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model:
2-
gpt-4o-mini
2+
4o-mini
33
description:
44
"An agent that will decompose the user's question into smaller parts to be used in the SQL queries. Use this agent when the user's question is too complex to be answered in one SQL query. Only use if the user's question is too complex to be answered in one SQL query."
55
system_message:

text_2_sql/autogen/agents/llm_agents/sql_query_correction_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model:
2-
gpt-4o-mini
2+
4o-mini
33
description:
44
"An agent that will look at the SQL query, SQL query results and correct any mistakes in the SQL query to ensure the correct results are returned. Use this agent AFTER the SQL query has been executed and the results are not as expected."
55
system_message:

text_2_sql/autogen/agents/llm_agents/sql_query_generation_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model:
2-
gpt-4o-mini
2+
4o-mini
33
description:
44
"An agent that can generate SQL queries once given the schema and the user's question. It will run the SQL query to fetch the results. Use this agent after the SQL Schema Selection Agent has selected the correct schema."
55
system_message:

text_2_sql/autogen/agents/llm_agents/sql_schema_selection_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model:
2-
gpt-4o-mini
2+
4o-mini
33
description:
44
"An agent that can take a user's question and extract the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term.
55

text_2_sql/autogen/utils/llm_agent_creator.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,56 @@
44
from autogen_core.components.tools import FunctionTool
55
from autogen_agentchat.agents import AssistantAgent
66
from utils.sql import query_execution, get_entity_schemas, query_validation
7-
from utils.models import MINI_MODEL
7+
from utils.models import GPT_4O_MINI_MODEL, GPT_4O_MODEL
88
from jinja2 import Template
99
from datetime import datetime
10+
from autogen_ext.models import AzureOpenAIChatCompletionClient
1011

1112

1213
class LLMAgentCreator:
1314
@classmethod
14-
def load_agent_file(cls, name):
15+
def load_agent_file(cls, name: str) -> dict:
16+
"""Loads the agent file based on the agent name.
17+
18+
Args:
19+
----
20+
name (str): The name of the agent to load.
21+
22+
Returns:
23+
-------
24+
dict: The agent file."""
1525
with open(f"./agents/llm_agents/{name.lower()}.yaml", "r") as file:
1626
file = yaml.safe_load(file)
1727

1828
return file
1929

2030
@classmethod
21-
def get_model(cls, model_name):
22-
if model_name == "gpt-4o-mini":
23-
return MINI_MODEL
31+
def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
32+
"""Retrieves the model based on the model name.
33+
34+
Args:
35+
----
36+
model_name (str): The name of the model to retrieve.
37+
38+
Returns:
39+
AzureOpenAIChatCompletionClient: The model client."""
40+
if model_name == "4o-mini":
41+
return GPT_4O_MINI_MODEL
42+
elif model_name == "4o":
43+
return GPT_4O_MODEL
2444
else:
2545
raise ValueError(f"Model {model_name} not found")
2646

2747
@classmethod
28-
def get_tool(cls, tool_name):
48+
def get_tool(cls, tool_name: str) -> FunctionTool:
49+
"""Retrieves the tool based on the tool name.
50+
51+
Args:
52+
----
53+
tool_name (str): The name of the tool to retrieve.
54+
55+
Returns:
56+
FunctionTool: The tool."""
2957
if tool_name == "sql_query_execution_tool":
3058
return FunctionTool(
3159
query_execution,
@@ -50,15 +78,38 @@ def get_tool(cls, tool_name):
5078
raise ValueError(f"Tool {tool_name} not found")
5179

5280
@classmethod
53-
def get_property_and_render_parameters(cls, agent_file, property, parameters):
81+
def get_property_and_render_parameters(
82+
cls, agent_file: dict, property: str, parameters: dict
83+
) -> str:
84+
"""Gets the property from the agent file and renders the parameters.
85+
86+
Args:
87+
----
88+
agent_file (dict): The agent file.
89+
property (str): The property to retrieve.
90+
parameters (dict): The parameters to render.
91+
92+
Returns:
93+
-------
94+
str: The rendered property."""
5495
unrendered_parameters = agent_file[property]
5596

5697
rendered_template = Template(unrendered_parameters).render(parameters)
5798

5899
return rendered_template
59100

60101
@classmethod
61-
def create(cls, name: str, **kwargs):
102+
def create(cls, name: str, **kwargs) -> AssistantAgent:
103+
"""Creates an assistant agent based on the agent name.
104+
105+
Args:
106+
----
107+
name (str): The name of the agent to create.
108+
**kwargs: The parameters to render.
109+
110+
Returns:
111+
-------
112+
AssistantAgent: The assistant agent."""
62113
agent_file = cls.load_agent_file(name)
63114

64115
tools = []

text_2_sql/autogen/utils/models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
1414
# )
1515

16-
MINI_MODEL = AzureOpenAIChatCompletionClient(
16+
GPT_4O_MINI_MODEL = AzureOpenAIChatCompletionClient(
1717
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
1818
model=os.environ["OpenAI__MiniCompletionDeployment"],
1919
api_version="2024-08-01-preview",
@@ -27,3 +27,18 @@
2727
"json_output": True,
2828
},
2929
)
30+
31+
GPT_4O_MODEL = AzureOpenAIChatCompletionClient(
32+
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
33+
model=os.environ["OpenAI__CompletionDeployment"],
34+
api_version="2024-08-01-preview",
35+
azure_endpoint=os.environ["OpenAI__Endpoint"],
36+
# # Optional if you choose key-based authentication.
37+
# azure_ad_token_provider=token_provider,
38+
api_key=os.environ["OpenAI__ApiKey"], # For key-based authentication.
39+
model_capabilities={
40+
"vision": False,
41+
"function_calling": True,
42+
"json_output": True,
43+
},
44+
)

text_2_sql/autogen/utils/sql.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ async def fetch_queries_from_cache(question: str) -> str:
138138
)
139139

140140
if len(cached_schemas) == 0:
141-
return {"cached_questions_and_schemas": None}
141+
return {"contains_pre_run_results": False, "cached_questions_and_schemas": None}
142142

143143
logging.info("Cached schemas: %s", cached_schemas)
144144
if PRE_RUN_QUERY_CACHE and len(cached_schemas) > 0:
@@ -165,6 +165,12 @@ async def fetch_queries_from_cache(question: str) -> str:
165165
"schemas": sql_query["Schemas"],
166166
}
167167

168-
return {"cached_questions_and_schemas": query_result_store}
168+
return {
169+
"contains_pre_run_results": True,
170+
"cached_questions_and_schemas": query_result_store,
171+
}
169172

170-
return {"cached_questions_and_schemas": cached_schemas}
173+
return {
174+
"contains_pre_run_results": False,
175+
"cached_questions_and_schemas": cached_schemas,
176+
}

0 commit comments

Comments
 (0)