Skip to content

Commit 525ba1c

Browse files
committed
Merge branch 'feature/structured-outputs' into fix/model-flow-errors
2 parents db9031e + 6ee8e5a commit 525ba1c

File tree

11 files changed

+108
-12
lines changed

11 files changed

+108
-12
lines changed

text_2_sql/.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Text2Sql__DatabaseEngine=<DatabaseEngine> # TSQL or Postgres or Snowflake or Dat
55
Text2Sql__UseQueryCache=<Determines if the Query Cache will be used to speed up query generation. Defaults to True.> # True or False
66
Text2Sql__PreRunQueryCache=<Determines if the results from the Query Cache will be pre-run to speed up answer generation. Defaults to True.> # True or False
77
Text2Sql__UseColumnValueStore=<Determines if the Column Value Store will be used for schema selection Defaults to True.> # True or False
8+
Text2Sql__GenerateFollowUpQuestions=<Determines if follow up questions will be generated. Defaults to True.> # True or False
89

910
# Open AI Connection Details
1011
OpenAI__CompletionDeployment=<openAICompletionDeploymentId. Used for data dictionary creator>

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
88
from jinja2 import Template
99
import logging
10+
from text_2_sql_core.structured_outputs import (
11+
AnswerAgentWithFollowUpQuestionsAgentOutput,
12+
UserMessageRewriteAgentOutput,
13+
)
1014

1115

1216
class LLMAgentCreator:
@@ -106,10 +110,20 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
106110
for tool in agent_file["tools"]:
107111
tools.append(cls.get_tool(sql_helper, tool))
108112

113+
structured_output = None
114+
if agent_file.get("structured_output", False):
115+
# Import the structured output agent
116+
if name == "answer_agent_with_follow_up_questions":
117+
structured_output = AnswerAgentWithFollowUpQuestionsAgentOutput
118+
elif name == "user_message_rewrite_agent":
119+
structured_output = UserMessageRewriteAgentOutput
120+
109121
agent = AssistantAgent(
110122
name=name,
111123
tools=tools,
112-
model_client=LLMModelCreator.get_model(agent_file["model"]),
124+
model_client=LLMModelCreator.get_model(
125+
agent_file["model"], structured_output=structured_output
126+
),
113127
description=cls.get_property_and_render_parameters(
114128
agent_file, "description", kwargs
115129
),

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
class LLMModelCreator:
1414
@classmethod
15-
def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
15+
def get_model(
16+
cls, model_name: str, structured_output=None
17+
) -> AzureOpenAIChatCompletionClient:
1618
"""Retrieves the model based on the model name.
1719
1820
Args:
@@ -22,9 +24,9 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
2224
Returns:
2325
AzureOpenAIChatCompletionClient: The model client."""
2426
if model_name == "4o-mini":
25-
return cls.gpt_4o_mini_model()
27+
return cls.gpt_4o_mini_model(structured_output=structured_output)
2628
elif model_name == "4o":
27-
return cls.gpt_4o_model()
29+
return cls.gpt_4o_model(structured_output=structured_output)
2830
else:
2931
raise ValueError(f"Model {model_name} not found")
3032

@@ -46,7 +48,9 @@ def get_authentication_properties(cls) -> dict:
4648
return token_provider, api_key
4749

4850
@classmethod
49-
def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
51+
def gpt_4o_mini_model(
52+
cls, structured_output=None
53+
) -> AzureOpenAIChatCompletionClient:
5054
token_provider, api_key = cls.get_authentication_properties()
5155
return AzureOpenAIChatCompletionClient(
5256
azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
@@ -61,10 +65,11 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
6165
"json_output": True,
6266
},
6367
temperature=0,
68+
response_format=structured_output,
6469
)
6570

6671
@classmethod
67-
def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
72+
def gpt_4o_model(cls, structured_output=None) -> AzureOpenAIChatCompletionClient:
6873
token_provider, api_key = cls.get_authentication_properties()
6974
return AzureOpenAIChatCompletionClient(
7075
azure_deployment=os.environ["OpenAI__CompletionDeployment"],
@@ -79,4 +84,5 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
7984
"json_output": True,
8085
},
8186
temperature=0,
87+
response_format=structured_output,
8288
)

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class Source(InteractionPayloadBase):
8181
answer: str
8282
steps: list[list[str]] = Field(default_factory=list, alias="Steps")
8383
sources: list[Source] = Field(default_factory=list)
84+
follow_up_questions: list[str] | None = Field(
85+
default=None, alias="followUpQuestions"
86+
)
87+
assistant_state: dict | None = Field(default=None, alias="assistantState")
8488

8589
payload_type: Literal[PayloadType.ANSWER_WITH_SOURCES] = Field(
8690
PayloadType.ANSWER_WITH_SOURCES, alias="payloadType"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
model: "4o-mini"
2+
description: "An agent that generates a response to a user's question."
3+
system_message: |
4+
<role_and_objective>
5+
You are Senior Data Analystm, specializing in providing data driven answers to a user's question. Use the general business use case of '{{ use_case }}' to aid understanding of the user's question. You should provide a clear and concise response based on the information obtained from the SQL queries and their results. Adopt a data-driven approach to generate the response.
6+
</role_and_objective>
7+
8+
<system_information>
9+
You are part of an overall system that provides Text2SQL and subsequent data analysis functionality only. You will be passed a result from multiple SQL queries, you must formulate a response to the user's question using this information.
10+
You can assume that the SQL queries are correct and that the results are accurate.
11+
You and the wider system can only generate SQL queries and process the results of these queries. You cannot access any external resources.
12+
The main ability of the system is to perform natural language understanding and generate SQL queries from the user's question. These queries are then automatically run against the database and the results are passed to you.
13+
</system_information>
14+
15+
<instructions>
16+
17+
Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.
18+
19+
Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results.
20+
21+
You have no access to the internet or any other external resources. You can only use the information provided in the SQL queries and their results, to generate the response.
22+
23+
You can use Markdown and Markdown tables to format the response. You MUST use the information obtained from the SQL queries to generate the response.
24+
25+
If the user is asking about your capabilities, use the <system_information> to explain what you do.
26+
27+
Make sure your response directly addresses every part of the user's question.
28+
29+
Finally, generate 3 data driven follow-up questions based on the information obtained from the SQL queries and their results. Think carefully about what questions may arise from the data and how they can be used to further analyze the data.
30+
31+
</instructions>
32+
33+
<output_
34+
structured_output: true

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_schema_selection_agent.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,4 @@ system_message: |
9696
<key_relationships>
9797
{{ relationship_paths }}
9898
</key_relationships>
99+
structured_output: true

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,4 @@ system_message: |
166166
}
167167
```
168168
</examples>
169+
structured_output: true
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from text_2_sql_core.structured_outputs.sql_schema_selection_agent import (
4+
SQLSchemaSelectionAgentOutput,
5+
)
6+
from text_2_sql_core.structured_outputs.user_message_rewrite_agent import (
7+
UserMessageRewriteAgentOutput,
8+
)
9+
from text_2_sql_core.structured_outputs.answer_agent_with_follow_up_questions import (
10+
AnswerAgentWithFollowUpQuestionsAgentOutput,
11+
)
12+
13+
__all__ = [
14+
"AnswerAgentWithFollowUpQuestionsAgentOutput",
15+
"SQLSchemaSelectionAgentOutput",
16+
"UserMessageRewriteAgentOutput",
17+
]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from pydantic import BaseModel
4+
5+
6+
class AnswerAgentWithFollowUpQuestionsAgentOutput(BaseModel):
7+
answer: str
8+
follow_up_questions: list[str]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from pydantic import BaseModel
4+
5+
6+
class UserMessageRewriteAgentOutput(BaseModel):
7+
decomposed_user_messages: list[list[str]]
8+
combination_logic: str
9+
query_type: str
10+
all_non_database_query: bool

0 commit comments

Comments
 (0)