Skip to content

Commit 0281500

Browse files
committed
Start adding Databricks support
1 parent 041af68 commit 0281500

File tree

6 files changed

+107
-10
lines changed

6 files changed

+107
-10
lines changed

deploy_ai_search/.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ OpenAI__Endpoint=<openAIEndpoint>
1919
OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
2020
OpenAI__EmbeddingDeployment=<openAIEmbeddingDeploymentId>
2121
OpenAI__EmbeddingDimensions=1536
22-
Text2Sql__DatabaseName=<databaseName>
22+
Text2Sql__DatabaseEngine=<databaseEngine SQL Server / Snowflake / Databricks >

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
from environment import (
2525
IndexerType,
2626
)
27+
import os
28+
from enum import StrEnum
29+
30+
31+
class DatabaseEngine(StrEnum):
32+
"""An enumeration to represent a database engine."""
33+
34+
SNOWFLAKE = "SNOWFLAKE"
35+
SQL_SERVER = "SQL_SERVER"
36+
DATABRICKS = "DATABRICKS"
2737

2838

2939
class Text2SqlSchemaStoreAISearch(AISearch):
@@ -42,13 +52,34 @@ def __init__(
4252
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
4353
"""
4454
self.indexer_type = IndexerType.TEXT_2_SQL_SCHEMA_STORE
55+
self.database_engine = DatabaseEngine[
56+
os.environ["Text2Sql__DatabaseEngine"].upper()
57+
]
4558
super().__init__(suffix, rebuild)
4659

4760
if single_data_dictionary:
4861
self.parsing_mode = BlobIndexerParsingMode.JSON_ARRAY
4962
else:
5063
self.parsing_mode = BlobIndexerParsingMode.JSON
5164

65+
@property
66+
def excluded_fields_for_database_engine(self):
67+
"""A method to get the excluded fields for the database engine."""
68+
69+
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
70+
if self.database_engine == DatabaseEngine.SNOWFLAKE:
71+
engine_specific_fields = ["Warehouse", "Database"]
72+
elif self.database_engine == DatabaseEngine.SQL_SERVER:
73+
engine_specific_fields = ["Database"]
74+
elif self.database_engine == DatabaseEngine.DATABRICKS:
75+
engine_specific_fields = ["Catalog"]
76+
77+
return [
78+
field
79+
for field in all_engine_specific_fields
80+
if field not in engine_specific_fields
81+
]
82+
5283
def get_index_fields(self) -> list[SearchableField]:
5384
"""This function returns the index fields for sql index.
5485
@@ -78,6 +109,10 @@ def get_index_fields(self) -> list[SearchableField]:
78109
name="Warehouse",
79110
type=SearchFieldDataType.String,
80111
),
112+
SearchableField(
113+
name="Catalog",
114+
type=SearchFieldDataType.String,
115+
),
81116
SearchableField(
82117
name="Definition",
83118
type=SearchFieldDataType.String,
@@ -161,6 +196,13 @@ def get_index_fields(self) -> list[SearchableField]:
161196
),
162197
]
163198

199+
# Remove fields that are not supported by the database engine
200+
fields = [
201+
field
202+
for field in fields
203+
if field.name not in self.excluded_fields_for_database_engine
204+
]
205+
164206
return fields
165207

166208
def get_semantic_search(self) -> SemanticSearch:
@@ -309,4 +351,12 @@ def get_indexer(self) -> SearchIndexer:
309351
parameters=indexer_parameters,
310352
)
311353

354+
# Remove fields that are not supported by the database engine
355+
indexer.output_field_mappings = [
356+
field_mapping
357+
for field_mapping in indexer.output_field_mappings
358+
if field_mapping.target_field_name
359+
not in self.excluded_fields_for_database_engine
360+
]
361+
312362
return indexer

text_2_sql/data_dictionary/.env

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ OpenAI__EmbeddingModel=<openAIEmbeddingModelName>
33
OpenAI__Endpoint=<openAIEndpoint>
44
OpenAI__ApiKey=<openAIKey if using non managed identity>
55
OpenAI__ApiVersion=<openAIApiVersion>
6-
Text2Sql__DatabaseEngine=<databaseEngine>
76
Text2Sql__DatabaseName=<databaseName>
87
Text2Sql__DatabaseConnectionString=<databaseConnectionString>
98
Text2Sql__Snowflake__User=<snowflakeUser if using Snowflake Data Source>
109
Text2Sql__Snowflake__Password=<snowflakePassword if using Snowflake Data Source>
1110
Text2Sql__Snowflake__Account=<snowflakeAccount if using Snowflake Data Source>
1211
Text2Sql__Snowflake__Warehouse=<snowflakeWarehouse if using Snowflake Data Source>
12+
Text2Sql__Databricks__Catalog=<databricksCatalog if using Databricks Data Source with Unity Catalog>
1313
IdentityType=<identityType> # system_assigned or user_assigned or key
14-
ClientId=<clientId if using user assigned identity>
14+
ClientId=<clientId if using user assigned identity>

text_2_sql/data_dictionary/data_dictionary_creator.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,19 @@
1515
import random
1616
import re
1717
import networkx as nx
18+
from enum import StrEnum
1819

1920
logging.basicConfig(level=logging.INFO)
2021

2122

23+
class DatabaseEngine(StrEnum):
24+
"""An enumeration to represent a database engine."""
25+
26+
SNOWFLAKE = "SNOWFLAKE"
27+
SQL_SERVER = "SQL_SERVER"
28+
DATABRICKS = "DATABRICKS"
29+
30+
2231
class ForeignKeyRelationship(BaseModel):
2332
column: str = Field(..., alias="Column")
2433
foreign_column: str = Field(..., alias="ForeignColumn")
@@ -124,6 +133,7 @@ class EntityItem(BaseModel):
124133
entity_name: Optional[str] = Field(default=None, alias="EntityName")
125134
database: Optional[str] = Field(default=None, alias="Database")
126135
warehouse: Optional[str] = Field(default=None, alias="Warehouse")
136+
catalog: Optional[str] = Field(default=None, alias="Catalog")
127137

128138
entity_relationships: Optional[list[EntityRelationship]] = Field(
129139
alias="EntityRelationships", default_factory=list
@@ -186,6 +196,9 @@ def __init__(
186196

187197
self.warehouse = None
188198
self.database = None
199+
self.catalog = None
200+
201+
self.database_engine = None
189202

190203
load_dotenv(find_dotenv())
191204

@@ -391,6 +404,7 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]:
391404
for entity in all_entities:
392405
entity.warehouse = self.warehouse
393406
entity.database = self.database
407+
entity.catalog = self.catalog
394408

395409
return all_entities
396410

@@ -636,6 +650,24 @@ async def build_entity_entry(self, entity: EntityItem) -> EntityItem:
636650

637651
return entity
638652

653+
@property
654+
def excluded_fields_for_database_engine(self):
655+
"""A method to get the excluded fields for the database engine."""
656+
657+
all_engine_specific_fields = ["Warehouse", "Database", "Catalog"]
658+
if self.database_engine == DatabaseEngine.SNOWFLAKE:
659+
engine_specific_fields = ["Warehouse", "Database"]
660+
elif self.database_engine == DatabaseEngine.SQL_SERVER:
661+
engine_specific_fields = ["Database"]
662+
elif self.database_engine == DatabaseEngine.DATABRICKS:
663+
engine_specific_fields = ["Catalog"]
664+
665+
return [
666+
field
667+
for field in all_engine_specific_fields
668+
if field not in engine_specific_fields
669+
]
670+
639671
async def create_data_dictionary(self):
640672
"""A method to build a data dictionary from a database. Writes to file."""
641673
entities = await self.extract_entities_with_definitions()
@@ -655,12 +687,23 @@ async def create_data_dictionary(self):
655687
logging.info("Saving data dictionary to entities.json")
656688
with open("entities.json", "w", encoding="utf-8") as f:
657689
json.dump(
658-
data_dictionary.model_dump(by_alias=True), f, indent=4, default=str
690+
data_dictionary.model_dump(
691+
by_alias=True, exclude=self.excluded_fields_for_database_engine
692+
),
693+
f,
694+
indent=4,
695+
default=str,
659696
)
660697
else:
661698
for entity in data_dictionary:
662699
logging.info(f"Saving data dictionary for {entity.entity}")
663700
with open(f"{entity.entity}.json", "w", encoding="utf-8") as f:
664701
json.dump(
665-
entity.model_dump(by_alias=True), f, indent=4, default=str
702+
entity.model_dump(
703+
by_alias=True,
704+
exclude=self.excluded_fields_for_database_engine,
705+
),
706+
f,
707+
indent=4,
708+
default=str,
666709
)

text_2_sql/data_dictionary/snowflake_data_dictionary_creator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from data_dictionary_creator import DataDictionaryCreator, EntityItem
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine
44
import asyncio
55
import snowflake.connector
66
import logging
@@ -25,9 +25,11 @@ def __init__(
2525
excluded_entities = []
2626

2727
excluded_schemas = ["INFORMATION_SCHEMA"]
28-
return super().__init__(
29-
entities, excluded_entities, excluded_schemas, single_file
30-
)
28+
super().__init__(entities, excluded_entities, excluded_schemas, single_file)
29+
30+
self.database = os.environ["Text2Sql__DatabaseName"]
31+
self.warehouse = os.environ["Text2Sql__Snowflake__Warehouse"]
32+
self.database_engine = DatabaseEngine.SNOWFLAKE
3133

3234
"""A class to extract data dictionary information from a Snowflake database."""
3335

text_2_sql/data_dictionary/sql_sever_data_dictionary_creator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from data_dictionary_creator import DataDictionaryCreator, EntityItem
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine
44
import asyncio
55
import os
66

@@ -26,6 +26,8 @@ def __init__(
2626
super().__init__(entities, excluded_entities, excluded_schemas, single_file)
2727
self.database = os.environ["Text2Sql__DatabaseName"]
2828

29+
self.database_engine = DatabaseEngine.SQL_SERVER
30+
2931
"""A class to extract data dictionary information from a SQL Server database."""
3032

3133
@property

0 commit comments

Comments
 (0)