Skip to content

Commit 6bd7837

Browse files
committed
Update engine specific rules logix
1 parent 4884a7d commit 6bd7837

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@
3131

3232

3333
class AutoGenText2Sql:
34-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
34+
def __init__(self, **kwargs: dict):
3535
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
36-
self.engine_specific_rules = engine_specific_rules
3736
self.kwargs = kwargs
3837

3938
def get_all_agents(self):
@@ -45,9 +44,7 @@ def get_all_agents(self):
4544
"question_rewrite_agent", current_datetime=current_datetime
4645
)
4746

48-
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(
49-
engine_specific_rules=self.engine_specific_rules, **self.kwargs
50-
)
47+
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
5148

5249
self.answer_agent = LLMAgentCreator.create("answer_agent")
5350

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from text_2_sql_core.prompts.load import load
77
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
88
from jinja2 import Template
9+
import logging
910

1011

1112
class LLMAgentCreator:
@@ -89,6 +90,17 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
8990

9091
sql_helper = ConnectorFactory.get_database_connector()
9192

93+
# Handle engine specific rules
94+
if "engine_specific_rules" not in kwargs:
95+
if sql_helper.engine_specific_rules is not None:
96+
kwargs["engine_specific_fields"] = sql_helper.engine_specific_rules
97+
logging.info(
98+
"Engine specific fields pulled from in-built: %s",
99+
kwargs["engine_specific_fields"],
100+
)
101+
else:
102+
kwargs["engine_specific_fields"] = ""
103+
92104
tools = []
93105
if "tools" in agent_file and len(agent_file["tools"]) > 0:
94106
for tool in agent_file["tools"]:

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020

2121

2222
class ParallelQuerySolvingAgent(BaseChatAgent):
23-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
23+
def __init__(self, **kwargs: dict):
2424
super().__init__(
2525
"parallel_query_solving_agent",
2626
"An agent that solves each query in parallel.",
2727
)
2828

29-
self.engine_specific_rules = engine_specific_rules
3029
self.kwargs = kwargs
3130

3231
@property
@@ -177,9 +176,7 @@ async def consume_inner_messages_from_agentic_flow(
177176
for question_rewrite in question_rewrites["sub_questions"]:
178177
logging.info(f"Processing sub-query: {question_rewrite}")
179178
# Create an instance of the InnerAutoGenText2Sql class
180-
inner_autogen_text_2_sql = InnerAutoGenText2Sql(
181-
self.engine_specific_rules, **self.kwargs
182-
)
179+
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
183180

184181
identifier = ", ".join(question_rewrite)
185182

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ async def on_messages_stream(self, messages, sender=None, config=None):
3838

3939

4040
class InnerAutoGenText2Sql:
41-
def __init__(self, engine_specific_rules: str, **kwargs: dict):
41+
def __init__(self, **kwargs: dict):
4242
self.pre_run_query_cache = False
4343
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
44-
self.engine_specific_rules = engine_specific_rules
4544
self.kwargs = kwargs
4645
self.set_mode()
4746

@@ -73,21 +72,18 @@ def get_all_agents(self):
7372

7473
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
7574
target_engine=self.target_engine,
76-
engine_specific_rules=self.engine_specific_rules,
7775
**self.kwargs,
7876
)
7977

8078
self.sql_query_correction_agent = LLMAgentCreator.create(
8179
"sql_query_correction_agent",
8280
target_engine=self.target_engine,
83-
engine_specific_rules=self.engine_specific_rules,
8481
**self.kwargs,
8582
)
8683

8784
self.disambiguation_and_sql_query_generation_agent = LLMAgentCreator.create(
8885
"disambiguation_and_sql_query_generation_agent",
8986
target_engine=self.target_engine,
90-
engine_specific_rules=self.engine_specific_rules,
9187
**self.kwargs,
9288
)
9389
agents = [

0 commit comments

Comments
 (0)