Skip to content

Commit 4e00504

Browse files
committed
Map engine specific deals
1 parent 6639705 commit 4e00504

File tree

5 files changed

+9
-6
lines changed

5 files changed

+9
-6
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from typing import Annotated
1414
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1515

16+
from text_2_sql_core.utils.database import DatabaseEngineSpecificFields
17+
1618

1719
class AISearchConnector:
1820
def __init__(self):
@@ -167,7 +169,7 @@ async def get_entity_schemas(
167169
"The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
168170
] = [],
169171
engine_specific_fields: Annotated[
170-
list[str],
172+
list[DatabaseEngineSpecificFields],
171173
"The fields specific to the engine to be included in the search results.",
172174
] = [],
173175
) -> str:
@@ -192,7 +194,7 @@ async def get_entity_schemas(
192194
"Columns",
193195
"EntityRelationships",
194196
"CompleteEntityRelationshipsGraph",
195-
] + engine_specific_fields
197+
] + list(map(str, engine_specific_fields))
196198

197199
schemas = await self.run_ai_search_query(
198200
text,

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ async def get_entity_schemas(
135135
"""
136136

137137
schemas = await self.ai_search_connector.get_entity_schemas(
138-
text, excluded_entities, engine_specific_fields=["Catalog"]
138+
text, excluded_entities, engine_specific_fields=self.engine_specific_fields
139139
)
140140

141141
for schema in schemas:

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,15 @@ async def get_entity_schemas(
111111
"""
112112

113113
schemas = await self.ai_search_connector.get_entity_schemas(
114-
text, excluded_entities
114+
text, excluded_entities, engine_specific_fields=self.engine_specific_fields
115115
)
116116

117117
for schema in schemas:
118118
schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])
119119

120120
del schema["Entity"]
121121
del schema["Schema"]
122+
del schema["Database"]
122123

123124
if as_json:
124125
return json.dumps(schemas, default=str)

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def get_entity_schemas(
150150
"""
151151

152152
schemas = await self.ai_search_connector.get_entity_schemas(
153-
text, excluded_entities, engine_specific_fields=["Warehouse", "Database"]
153+
text, excluded_entities, engine_specific_fields=self.engine_specific_fields
154154
)
155155

156156
for schema in schemas:

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ async def get_entity_schemas(
124124
"""
125125

126126
schemas = await self.ai_search_connector.get_entity_schemas(
127-
text, excluded_entities
127+
text, excluded_entities, engine_specific_fields=self.engine_specific_fields
128128
)
129129

130130
for schema in schemas:

0 commit comments

Comments
 (0)