Skip to content

Commit 408e30f

Browse files
Add Basic Engine Specific Instructions into Connectors (#134)
1 parent f5f7137 commit 408e30f

File tree

9 files changed

+47
-17
lines changed

9 files changed

+47
-17
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
"metadata": {},
8585
"outputs": [],
8686
"source": [
87-
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"\", use_case=\"Analysing sales data\")"
87+
"agentic_text_2_sql = AutoGenText2Sql(use_case=\"Analysing sales data\")"
8888
]
8989
},
9090
{
@@ -100,9 +100,16 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n",
103+
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"metadata": {},
111+
"outputs": [],
112+
"source": []
106113
}
107114
],
108115
"metadata": {

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@
3131

3232

3333
class AutoGenText2Sql:
34-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
34+
def __init__(self, **kwargs: dict):
3535
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
36-
self.engine_specific_rules = engine_specific_rules
3736
self.kwargs = kwargs
3837

3938
def get_all_agents(self):
@@ -45,9 +44,7 @@ def get_all_agents(self):
4544
"question_rewrite_agent", current_datetime=current_datetime
4645
)
4746

48-
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(
49-
engine_specific_rules=self.engine_specific_rules, **self.kwargs
50-
)
47+
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
5148

5249
self.answer_agent = LLMAgentCreator.create("answer_agent")
5350

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from text_2_sql_core.prompts.load import load
77
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
88
from jinja2 import Template
9+
import logging
910

1011

1112
class LLMAgentCreator:
@@ -89,6 +90,17 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
8990

9091
sql_helper = ConnectorFactory.get_database_connector()
9192

93+
# Handle engine specific rules
94+
if "engine_specific_rules" not in kwargs:
95+
if sql_helper.engine_specific_rules is not None:
96+
kwargs["engine_specific_fields"] = sql_helper.engine_specific_rules
97+
logging.info(
98+
"Engine specific fields pulled from in-built: %s",
99+
kwargs["engine_specific_fields"],
100+
)
101+
else:
102+
kwargs["engine_specific_fields"] = ""
103+
92104
tools = []
93105
if "tools" in agent_file and len(agent_file["tools"]) > 0:
94106
for tool in agent_file["tools"]:

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020

2121

2222
class ParallelQuerySolvingAgent(BaseChatAgent):
23-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
23+
def __init__(self, **kwargs: dict):
2424
super().__init__(
2525
"parallel_query_solving_agent",
2626
"An agent that solves each query in parallel.",
2727
)
2828

29-
self.engine_specific_rules = engine_specific_rules
3029
self.kwargs = kwargs
3130

3231
@property
@@ -177,9 +176,7 @@ async def consume_inner_messages_from_agentic_flow(
177176
for question_rewrite in question_rewrites["sub_questions"]:
178177
logging.info(f"Processing sub-query: {question_rewrite}")
179178
# Create an instance of the InnerAutoGenText2Sql class
180-
inner_autogen_text_2_sql = InnerAutoGenText2Sql(
181-
self.engine_specific_rules, **self.kwargs
182-
)
179+
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
183180

184181
identifier = ", ".join(question_rewrite)
185182

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ async def on_messages_stream(self, messages, sender=None, config=None):
3838

3939

4040
class InnerAutoGenText2Sql:
41-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
41+
def __init__(self, **kwargs: dict):
4242
self.pre_run_query_cache = False
4343
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
44-
self.engine_specific_rules = engine_specific_rules
4544
self.kwargs = kwargs
4645
self.set_mode()
4746

@@ -73,21 +72,18 @@ def get_all_agents(self):
7372

7473
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
7574
target_engine=self.target_engine,
76-
engine_specific_rules=self.engine_specific_rules,
7775
**self.kwargs,
7876
)
7977

8078
self.sql_query_correction_agent = LLMAgentCreator.create(
8179
"sql_query_correction_agent",
8280
target_engine=self.target_engine,
83-
engine_specific_rules=self.engine_specific_rules,
8481
**self.kwargs,
8582
)
8683

8784
self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
8885
"disambiguation_and_sql_query_generation_agent",
8986
target_engine=self.target_engine,
90-
engine_specific_rules=self.engine_specific_rules,
9187
**self.kwargs,
9288
)
9389
agents = [

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.DATABRICKS
1919

20+
@property
21+
def engine_specific_rules(self) -> str:
22+
"""Get the engine specific rules."""
23+
return
24+
2025
@property
2126
def engine_specific_fields(self) -> list[str]:
2227
"""Get the engine specific fields."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.SNOWFLAKE
1919

20+
@property
21+
def engine_specific_rules(self) -> str:
22+
"""Get the engine specific rules."""
23+
return """When an ORDER BY clause is included in the SQL query, always append the ORDER BY clause with 'NULLS LAST' to ensure that NULL values are at the end of the result set. e.g. 'ORDER BY column_name DESC NULLS LAST'."""
24+
2025
@property
2126
def engine_specific_fields(self) -> list[str]:
2227
"""Get the engine specific fields."""

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def __init__(self):
3131

3232
self.database_engine = None
3333

34+
@property
35+
@abstractmethod
36+
def engine_specific_rules(self) -> str:
37+
"""Get the engine specific rules."""
38+
pass
39+
3440
@property
3541
@abstractmethod
3642
def invalid_identifiers(self) -> list[str]:

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ def __init__(self):
1616

1717
self.database_engine = DatabaseEngine.TSQL
1818

19+
@property
20+
def engine_specific_rules(self) -> str:
21+
"""Get the engine specific rules."""
22+
return """Use TOP X instead of LIMIT X to limit the number of rows returned."""
23+
1924
@property
2025
def engine_specific_fields(self) -> list[str]:
2126
"""Get the engine specific fields."""

0 commit comments

Comments
 (0)