Skip to content

Commit 825d2c8

Browse files
committed
Store excluded fields in
1 parent 190f5d2 commit 825d2c8

File tree

9 files changed

+63
-41
lines changed

9 files changed

+63
-41
lines changed

deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
import os
2828
from text_2_sql_core.utils.database import DatabaseEngine
29+
from text_2_sql_core.connectors.factory import ConnectorFactory
2930

3031

3132
class Text2SqlSchemaStoreAISearch(AISearch):
@@ -49,31 +50,13 @@ def __init__(
4950
os.environ["Text2Sql__DatabaseEngine"].upper()
5051
]
5152

53+
self.database_connector = ConnectorFactory.get_database_connector()
54+
5255
if single_data_dictionary_file:
5356
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
5457
else:
5558
self.parsing_mode = BlobIndexerParsingMode.JSON
5659

57-
@property
58-
def excluded_fields_for_database_engine(self):
59-
"""A method to get the excluded fields for the database engine."""
60-
61-
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
62-
if self.database_engine == DatabaseEngine.SNOWFLAKE:
63-
engine_specific_fields = ["Warehouse", "Database"]
64-
elif self.database_engine == DatabaseEngine.TSQL:
65-
engine_specific_fields = ["Database"]
66-
elif self.database_engine == DatabaseEngine.DATABRICKS:
67-
engine_specific_fields = ["Catalog"]
68-
elif self.database_engine == DatabaseEngine.POSTGRESQL:
69-
engine_specific_fields = ["Database"]
70-
71-
return [
72-
field
73-
for field in all_engine_specific_fields
74-
if field not in engine_specific_fields
75-
]
76-
7760
def get_index_fields(self) -> list[SearchableField]:
7861
"""This function returns the index fields for sql index.
7962
@@ -198,7 +181,7 @@ def get_index_fields(self) -> list[SearchableField]:
198181
fields = [
199182
field
200183
for field in fields
201-
if field.name not in self.excluded_fields_for_database_engine
184+
if field.name not in self.database_connector.excluded_engine_specific_fields
202185
]
203186

204187
return fields

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import json
1010

11-
from text_2_sql_core.utils.database import DatabaseEngine
11+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1212

1313

1414
class DatabricksSqlConnector(SqlConnector):
@@ -17,6 +17,11 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.DATABRICKS
1919

20+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields declared by the Databricks connector.

    Returns:
        list[str]: A single-element list holding
        ``DatabaseEngineSpecificFields.CATALOG``.
    """
    catalog_field = DatabaseEngineSpecificFields.CATALOG
    return [catalog_field]
24+
2025
async def query_execution(
2126
self,
2227
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def get_database_connector():
2525
from text_2_sql_core.connectors.tsql_sql import TSQLSqlConnector
2626

2727
return TSQLSqlConnector()
28+
elif os.environ["Text2Sql__DatabaseEngine"].upper() == "POSTGRESQL":
29+
from text_2_sql_core.connectors.postgresql_sql import (
30+
PostgresqlSqlConnector,
31+
)
32+
33+
return PostgresqlSqlConnector()
2834
else:
2935
raise ValueError(
3036
f"""Database engine {

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/postgresql_sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import json
99

10-
from text_2_sql_core.utils.database import DatabaseEngine
10+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1111

1212

1313
class PostgresqlSqlConnector(SqlConnector):
@@ -16,6 +16,11 @@ def __init__(self):
1616

1717
self.database_engine = DatabaseEngine.POSTGRESQL
1818

19+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields declared by the PostgreSQL connector.

    Returns:
        list[str]: A single-element list holding
        ``DatabaseEngineSpecificFields.DATABASE``.
    """
    database_field = DatabaseEngineSpecificFields.DATABASE
    return [database_field]
23+
1924
async def query_execution(
2025
self,
2126
sql_query: Annotated[str, "The SQL query to run against the database."],

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/snowflake_sql.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import json
1010

11-
from text_2_sql_core.utils.database import DatabaseEngine
11+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1212

1313

1414
class SnowflakeSqlConnector(SqlConnector):
@@ -17,6 +17,14 @@ def __init__(self):
1717

1818
self.database_engine = DatabaseEngine.SNOWFLAKE
1919

20+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields declared by the Snowflake connector.

    Returns:
        list[str]: ``DatabaseEngineSpecificFields.WAREHOUSE`` followed by
        ``DatabaseEngineSpecificFields.DATABASE``.
    """
    warehouse_field = DatabaseEngineSpecificFields.WAREHOUSE
    database_field = DatabaseEngineSpecificFields.DATABASE
    return [warehouse_field, database_field]
27+
2028
async def query_execution(
2129
self,
2230
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import ABC, abstractmethod
1010
from jinja2 import Template
1111
import json
12+
from text_2_sql_core.utils.database import DatabaseEngineSpecificFields
1213

1314

1415
class SqlConnector(ABC):
@@ -29,6 +30,22 @@ def __init__(self):
2930

3031
self.database_engine = None
3132

33+
# NOTE: ``@abstractmethod`` must be the innermost decorator. Stacking it on top
# of ``@property`` raises AttributeError while the class body executes, because
# ``property.__isabstractmethod__`` is a read-only attribute.
@property
@abstractmethod
def engine_specific_fields(self) -> list[str]:
    """Get the engine specific fields.

    Abstract property: each concrete connector overrides it to declare
    which engine specific fields it uses.

    Returns:
        list[str]: The engine specific fields used by the connector.
    """
38+
39+
@property
def excluded_engine_specific_fields(self):
    """List the engine specific fields that this connector does not declare.

    Every member of ``DatabaseEngineSpecificFields`` is checked against
    ``self.engine_specific_fields``; members that are absent are kept and
    their values are returned after ``str.capitalize`` is applied.
    """
    excluded = []
    for candidate in DatabaseEngineSpecificFields:
        # Fields the connector declares for its engine are not excluded.
        if candidate in self.engine_specific_fields:
            continue
        excluded.append(candidate.value.capitalize())

    return excluded
48+
3249
@abstractmethod
3350
async def query_execution(
3451
self,

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import json
99

10-
from text_2_sql_core.utils.database import DatabaseEngine
10+
from text_2_sql_core.utils.database import DatabaseEngine, DatabaseEngineSpecificFields
1111

1212

1313
class TSQLSqlConnector(SqlConnector):
@@ -16,6 +16,11 @@ def __init__(self):
1616

1717
self.database_engine = DatabaseEngine.TSQL
1818

19+
@property
def engine_specific_fields(self) -> list[str]:
    """Return the engine specific fields declared by the TSQL connector.

    Returns:
        list[str]: A single-element list holding
        ``DatabaseEngineSpecificFields.DATABASE``.
    """
    database_field = DatabaseEngineSpecificFields.DATABASE
    return [database_field]
23+
1924
async def query_execution(
2025
self,
2126
sql_query: Annotated[

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import random
1212
import re
1313
import networkx as nx
14-
from text_2_sql_core.utils.database import DatabaseEngine
1514
from tenacity import retry, stop_after_attempt, wait_exponential
1615
from text_2_sql_core.connectors.open_ai import OpenAIConnector
1716

@@ -751,23 +750,9 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
751750
def excluded_fields_for_database_engine(self):
752751
"""A method to get the excluded fields for the database engine."""
753752

754-
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
755-
if self.database_engine == DatabaseEngine.SNOWFLAKE:
756-
engine_specific_fields = ["Warehouse", "Database"]
757-
elif self.database_engine == DatabaseEngine.TSQL:
758-
engine_specific_fields = ["Database"]
759-
elif self.database_engine == DatabaseEngine.DATABRICKS:
760-
engine_specific_fields = ["Catalog"]
761-
elif self.database_engine == DatabaseEngine.POSTGRESQL:
762-
engine_specific_fields = ["Database"]
763-
else:
764-
engine_specific_fields = []
765-
766753
# Determine top-level fields to exclude
767754
filtered_entitiy_specific_fields = {
768-
field.lower(): ...
769-
for field in all_engine_specific_fields
770-
if field not in engine_specific_fields
755+
field.lower(): ... for field in self.excluded_engine_specific_fields
771756
}
772757

773758
if filtered_entitiy_specific_fields:

text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,11 @@ class DatabaseEngine(StrEnum):
88
SNOWFLAKE = "SNOWFLAKE"
99
TSQL = "TSQL"
1010
POSTGRESQL = "POSTGRESQL"
11+
12+
13+
class DatabaseEngineSpecificFields(StrEnum):
14+
"""An enumeration to represent the database engine specific fields."""
15+
16+
WAREHOUSE = "Warehouse"
17+
DATABASE = "Database"
18+
CATALOG = "Catalog"

0 commit comments

Comments
 (0)