Skip to content

Commit 75d962b

Browse files
committed
Updating databricks connector
1 parent 0281500 commit 75d962b

File tree

4 files changed

+153
-1
lines changed

4 files changed

+153
-1
lines changed

text_2_sql/data_dictionary/.env

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,8 @@ Text2Sql__Snowflake__Password=<snowflakePassword if using Snowflake Data Source>
1010
Text2Sql__Snowflake__Account=<snowflakeAccount if using Snowflake Data Source>
1111
Text2Sql__Snowflake__Warehouse=<snowflakeWarehouse if using Snowflake Data Source>
1212
Text2Sql__Databricks__Catalog=<databricksCatalog if using Databricks Data Source with Unity Catalog>
13+
Text2Sql__Databricks__ServerHostname=<databricksServerHostname if using Databricks Data Source with Unity Catalog>
14+
Text2Sql__Databricks__HttpPath=<databricksHttpPath if using Databricks Data Source with Unity Catalog>
15+
Text2Sql__Databricks__AccessToken=<databricksAccessToken if using Databricks Data Source with Unity Catalog>
1316
IdentityType=<identityType> # system_assigned or user_assigned or key
1417
ClientId=<clientId if using user assigned identity>

text_2_sql/data_dictionary/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ See `./generated_samples/` for an example output of the script. This can then be
9999

100100
The following Databases have pre-built scripts for them:
101101

102-
- **Microsoft SQL Server:** `sql_server_data_dictionary_creator.py`
102+
- **Databricks:** `databricks_data_dictionary_creator.py`
103103
- **Snowflake:** `snowflake_data_dictionary_creator.py`
104+
- **SQL Server:** `sql_server_data_dictionary_creator.py`
104105

105106
If there is no pre-built script for your database engine, take one of the above as a starting point and adjust it.
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem, DatabaseEngine
4+
import asyncio
5+
from databricks import sql
6+
import logging
7+
import os
8+
9+
10+
class SnowflakeDataDictionaryCreator(DataDictionaryCreator):
    """Extract data dictionary information from a Databricks catalog.

    NOTE: Despite its name, this class targets Databricks: it sets
    ``DatabaseEngine.DATABRICKS`` and connects through ``databricks.sql``.
    The name is kept so existing references keep working; the
    ``DatabricksDataDictionaryCreator`` alias defined right after the class
    is the preferred name for new code.
    """

    def __init__(
        self,
        entities: list[str] = None,
        excluded_entities: list[str] = None,
        single_file: bool = False,
    ):
        """Initialize the Databricks data dictionary creator.

        Args:
            entities (list[str], optional): A list of entities to extract. Defaults to None. If None, all entities are extracted.
            excluded_entities (list[str], optional): A list of entities to exclude. Defaults to None.
            single_file (bool, optional): A flag to indicate if the data dictionary should be saved to a single file. Defaults to False.

        Raises:
            KeyError: If ``Text2Sql__Databricks__Catalog`` is not set in the environment.
        """
        if excluded_entities is None:
            excluded_entities = []

        # No schemas are excluded for this engine.
        excluded_schemas = []
        super().__init__(entities, excluded_entities, excluded_schemas, single_file)

        self.catalog = os.environ["Text2Sql__Databricks__Catalog"]
        self.database_engine = DatabaseEngine.DATABRICKS

    @property
    def extract_table_entities_sql_query(self) -> str:
        """A property returning the SQL that lists tables in the configured catalog."""
        return f"""SELECT
            t.TABLE_NAME AS Entity,
            t.TABLE_SCHEMA AS EntitySchema,
            t.COMMENT AS Definition
        FROM
            INFORMATION_SCHEMA.TABLES t
        WHERE
            t.TABLE_CATALOG = '{self.catalog}'
        """

    @property
    def extract_view_entities_sql_query(self) -> str:
        """A property returning the SQL that lists views in the configured catalog.

        ``Definition`` is always NULL because the query selects no comment
        column for views.
        """
        # Must be an f-string so the catalog name is interpolated; the comma
        # after EntitySchema is required for the statement to parse.
        return f"""SELECT
            v.TABLE_NAME AS Entity,
            v.TABLE_SCHEMA AS EntitySchema,
            NULL AS Definition
        FROM
            INFORMATION_SCHEMA.VIEWS v
        WHERE
            v.TABLE_CATALOG = '{self.catalog}'"""

    def extract_columns_sql_query(self, entity: EntityItem) -> str:
        """A method returning the SQL that lists the columns of one entity.

        Args:
            entity (EntityItem): The entity whose ``entity_schema`` and ``name`` select the columns.

        Returns:
            str: The SQL query for the entity's column names, types and comments.
        """
        # NOTE(review): the values are interpolated directly into the SQL text;
        # this presumes they originate from the catalog metadata queries above
        # rather than from untrusted input - confirm against the caller.
        return f"""SELECT
            COLUMN_NAME AS Name,
            DATA_TYPE AS Type,
            COMMENT AS Definition
        FROM
            INFORMATION_SCHEMA.COLUMNS
        WHERE
            TABLE_CATALOG = '{self.catalog}'
            AND TABLE_SCHEMA = '{entity.entity_schema}'
            AND TABLE_NAME = '{entity.name}';"""

    @property
    def extract_entity_relationships_sql_query(self) -> str:
        """A property returning the SQL that lists foreign key constraints.

        Unlike the other queries, this one is not filtered by ``self.catalog``.
        """
        return """SELECT
            tc.table_schema AS EntitySchema,
            tc.table_name AS Entity,
            rc.unique_constraint_schema AS ForeignEntitySchema,
            rc.unique_constraint_name AS ForeignEntityConstraint,
            rc.constraint_name AS ForeignKeyConstraint
        FROM
            information_schema.referential_constraints rc
        JOIN
            information_schema.table_constraints tc
            ON rc.constraint_schema = tc.constraint_schema
            AND rc.constraint_name = tc.constraint_name
        WHERE
            tc.constraint_type = 'FOREIGN KEY'
        ORDER BY
            EntitySchema, Entity, ForeignEntitySchema, ForeignEntityConstraint;
        """

    async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict]:
        """
        A method to query a Databricks SQL endpoint for entities.

        Args:
            sql_query (str): The SQL query to run.
            cast_to (any, optional): The class to cast the results to. Defaults to None.
                When given, each row is converted with ``cast_to.from_sql_row(row, columns)``.

        Returns:
            list[dict]: The list of entities or processed rows.

        Raises:
            KeyError: If a required ``Text2Sql__Databricks__*`` environment variable is missing.
            Exception: Any error raised while running the query is logged and re-raised.
        """
        logging.info(f"Running query: {sql_query}")

        # Connecting is blocking, so it is run in a worker thread like the
        # execute/fetch calls below to keep the event loop responsive.
        connection = await asyncio.to_thread(
            sql.connect,
            server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"],
            http_path=os.environ["Text2Sql__Databricks__HttpPath"],
            access_token=os.environ["Text2Sql__Databricks__AccessToken"],
        )

        # Bound before the try so the finally clause can tell whether a cursor
        # was actually created.
        cursor = None
        try:
            cursor = connection.cursor()

            # Run the blocking driver calls off the event loop.
            await asyncio.to_thread(cursor.execute, sql_query)

            # The first item of each description entry is the column name.
            columns = [col[0] for col in cursor.description]

            rows = await asyncio.to_thread(cursor.fetchall)

            if cast_to:
                results = [cast_to.from_sql_row(row, columns) for row in rows]
            else:
                results = [dict(zip(columns, row)) for row in rows]

        except Exception as e:
            logging.error(f"Error while executing query: {e}")
            raise
        finally:
            # Always release the connection, even if closing the cursor fails.
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                connection.close()

        return results


# Correctly named alias; the original class name is kept for backward compatibility.
DatabricksDataDictionaryCreator = SnowflakeDataDictionaryCreator
143+
144+
145+
if __name__ == "__main__":
    # Script entry point: build the creator with its default settings and
    # drive the asynchronous data dictionary creation to completion.
    creator = SnowflakeDataDictionaryCreator()
    asyncio.run(creator.create_data_dictionary())

text_2_sql/data_dictionary/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pydantic
55
openai
66
snowflake-connector-python
77
networkx
8+
# Provides the `databricks.sql` module imported by databricks_data_dictionary_creator.py.
databricks-sql-connector

0 commit comments

Comments
 (0)