Skip to content

Commit 1678189

Browse files
committed
Update creator scripts
1 parent f6801e1 commit 1678189

File tree

7 files changed

+64
-44
lines changed

7 files changed

+64
-44
lines changed

deploy_ai_search/text_2_sql_query_cache.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,6 @@ def get_index_fields(self) -> list[SearchableField]:
112112
SearchableField(
113113
name="DataType", type=SearchFieldDataType.String
114114
),
115-
SearchableField(
116-
name="AllowedValues",
117-
type=SearchFieldDataType.String,
118-
collection=True,
119-
searchable=False,
120-
),
121115
SearchableField(
122116
name="SampleValues",
123117
type=SearchFieldDataType.String,

deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,6 @@ def get_index_fields(self) -> list[SearchableField]:
126126
SearchableField(name="Name", type=SearchFieldDataType.String),
127127
SearchableField(name="Definition", type=SearchFieldDataType.String),
128128
SearchableField(name="DataType", type=SearchFieldDataType.String),
129-
SearchableField(
130-
name="AllowedValues",
131-
type=SearchFieldDataType.String,
132-
collection=True,
133-
searchable=False,
134-
),
135129
SearchableField(
136130
name="SampleValues",
137131
type=SearchFieldDataType.String,

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def create(
8282
rich_print(detailed_error)
8383

8484
raise typer.Exit(code=1)
85-
85+
asyncio.run(data_dictionary_creator.create_data_dictionary())
8686
try:
8787
asyncio.run(data_dictionary_creator.create_data_dictionary())
8888
except Exception as e:

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ class ColumnItem(BaseModel):
100100
distinct_values: Optional[list[any]] = Field(
101101
None, alias="DistinctValues", exclude=True
102102
)
103-
allowed_values: Optional[list[any]] = Field(None, alias="AllowedValues")
104103
sample_values: Optional[list[any]] = Field(None, alias="SampleValues")
105104

106105
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
@@ -160,12 +159,12 @@ def id(self):
160159

161160
def value_store_entry(self, excluded_fields_for_database_engine):
162161
excluded_fields = excluded_fields_for_database_engine + [
163-
"Definition",
164-
"Name",
165-
"EntityName",
166-
"EntityRelationships",
167-
"CompleteEntityRelationshipsGraph",
168-
"Columns",
162+
"definition",
163+
"name",
164+
"entity_name",
165+
"entity_relationships",
166+
"complete_entity_relationships_graph",
167+
"columns",
169168
]
170169
return self.model_dump(
171170
by_alias=True, exclude_none=True, exclude=excluded_fields
@@ -226,8 +225,8 @@ def __init__(
226225

227226
self.database_engine = None
228227

229-
self.database_semaphore = asyncio.Semaphore(50)
230-
self.llm_semaphone = asyncio.Semaphore(20)
228+
self.database_semaphore = asyncio.Semaphore(20)
229+
self.llm_semaphone = asyncio.Semaphore(10)
231230

232231
if output_directory is None:
233232
self.output_directory = "."
@@ -274,7 +273,7 @@ def extract_distinct_values_sql_query(
274273
Returns:
275274
str: The SQL query to extract distinct values from a column.
276275
"""
277-
return f"""SELECT DISTINCT {column.name} FROM {entity.entity} ORDER BY {column.name} DESC;"""
276+
return f"""SELECT DISTINCT {column.name} FROM {entity.entity} WHERE {column.name} IS NOT NULL ORDER BY {column.name} DESC;"""
278277

279278
@retry(
280279
stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10)
@@ -422,13 +421,13 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]:
422421
all_entities = table_entities + view_entities
423422

424423
# Filter entities if entities is not None
425-
if self.entities:
424+
if self.entities is not None:
426425
all_entities = [
427426
entity for entity in all_entities if entity.entity in self.entities
428427
]
429428

430429
# Filter entities if excluded_entities is not None
431-
if self.excluded_entities:
430+
if len(self.excluded_entities) > 0 or len(self.excluded_schemas):
432431
all_entities = [
433432
entity
434433
for entity in all_entities
@@ -448,20 +447,24 @@ async def write_columns_to_file(self, entity: EntityItem, column: ColumnItem):
448447
logging.info(f"Saving column values for {column.name}")
449448

450449
key = f"{entity.id}.{column.name}"
450+
# Ensure the intermediate directories exist
451+
os.makedirs(f"{self.output_directory}/column_value_store", exist_ok=True)
451452
with open(
452453
f"{self.output_directory}/column_value_store/{key}.jsonl",
453454
"w",
454455
encoding="utf-8",
455456
) as f:
456-
for distinct_value in column.distinct_values:
457-
json.dump(
458-
column.value_store_entry(
459-
entity, distinct_value, self.excluded_fields_for_database_engine
460-
),
461-
f,
462-
indent=4,
463-
default=str,
464-
)
457+
if column.distinct_values is not None:
458+
for distinct_value in column.distinct_values:
459+
json_string = json.dumps(
460+
column.value_store_entry(
461+
entity,
462+
distinct_value,
463+
self.excluded_fields_for_database_engine,
464+
),
465+
default=str,
466+
)
467+
f.write(json_string + "\n")
465468

466469
async def extract_column_distinct_values(
467470
self, entity: EntityItem, column: ColumnItem
@@ -498,6 +501,8 @@ async def extract_column_distinct_values(
498501
elif column.distinct_values is not None:
499502
column.sample_values = column.distinct_values
500503

504+
await self.write_columns_to_file(entity, column)
505+
501506
async def generate_column_definition(self, entity: EntityItem, column: ColumnItem):
502507
"""A method to generate a definition for a column in a database.
503508
@@ -682,6 +687,9 @@ async def generate_entity_definition(self, entity: EntityItem):
682687

683688
async def write_entity_to_file(self, entity):
684689
logging.info(f"Saving data dictionary for {entity.entity}")
690+
691+
# Ensure the intermediate directories exist
692+
os.makedirs(f"{self.output_directory}/schema_store", exist_ok=True)
685693
with open(
686694
f"{self.output_directory}/schema_store/{entity.id}.json",
687695
"w",
@@ -742,7 +750,7 @@ def excluded_fields_for_database_engine(self):
742750
engine_specific_fields = ["Catalog"]
743751

744752
return [
745-
field
753+
field.lower()
746754
for field in all_engine_specific_fields
747755
if field not in engine_specific_fields
748756
]
@@ -764,6 +772,8 @@ async def create_data_dictionary(self):
764772
# Save data dictionary to file
765773
if self.single_file:
766774
logging.info("Saving data dictionary to entities.json")
775+
# Ensure the intermediate directories exist
776+
os.makedirs(f"{self.output_directory}/schema_store", exist_ok=True)
767777
with open(
768778
f"{self.output_directory}/schema_store/entities.json",
769779
"w",

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/databricks_data_dictionary_creator.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from data_dictionary_creator import DataDictionaryCreator, EntityItem
3+
from data_dictionary_creator import DataDictionaryCreator, EntityItem, ColumnItem
44
import asyncio
55
from databricks import sql
66
import logging
@@ -30,6 +30,8 @@ def extract_table_entities_sql_query(self) -> str:
3030
t.COMMENT AS Definition
3131
FROM
3232
{self.catalog}.INFORMATION_SCHEMA.TABLES t
33+
ORDER BY EntitySchema, Entity
34+
LIMIT 1
3335
"""
3436

3537
@property
@@ -41,6 +43,8 @@ def extract_view_entities_sql_query(self) -> str:
4143
NULL AS Definition
4244
FROM
4345
{self.catalog}.INFORMATION_SCHEMA.VIEWS v
46+
ORDER BY EntitySchema, Entity
47+
LIMIT 1
4448
"""
4549

4650
def extract_columns_sql_query(self, entity: EntityItem) -> str:
@@ -94,6 +98,20 @@ def extract_entity_relationships_sql_query(self) -> str:
9498
EntitySchema, Entity, ForeignEntitySchema, ForeignEntity;
9599
"""
96100

101+
def extract_distinct_values_sql_query(
102+
self, entity: EntityItem, column: ColumnItem
103+
) -> str:
104+
"""A method to extract distinct values from a column in a database. Can be sub-classed if needed.
105+
106+
Args:
107+
entity (EntityItem): The entity to extract distinct values from.
108+
column (ColumnItem): The column to extract distinct values from.
109+
110+
Returns:
111+
str: The SQL query to extract distinct values from a column.
112+
"""
113+
return f"""SELECT DISTINCT {column.name} FROM {self.catalog}.{entity.entity} WHERE {column.name} IS NOT NULL ORDER BY {column.name} DESC;"""
114+
97115
@retry(
98116
stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10)
99117
)
@@ -134,14 +152,14 @@ async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict
134152

135153
# Process rows
136154
for row in rows:
137-
if cast_to:
155+
if cast_to is not None:
138156
results.append(cast_to.from_sql_row(row, columns))
139157
else:
140158
results.append(dict(zip(columns, row)))
141159

142160
except Exception as e:
143-
logging.error(f"Error while executing query: {e}")
144-
raise
161+
logging.error(f"Error while executing query {sql_query}: {e}")
162+
raise e
145163
finally:
146164
cursor.close()
147165
connection.close()

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/snowflake_data_dictionary_creator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def extract_table_entities_sql_query(self) -> str:
3131
t.TABLE_SCHEMA AS EntitySchema,
3232
t.COMMENT AS Definition
3333
FROM
34-
INFORMATION_SCHEMA.TABLES t"""
34+
INFORMATION_SCHEMA.TABLES t
35+
ORDER BY EntitySchema, Entity"""
3536

3637
@property
3738
def extract_view_entities_sql_query(self) -> str:
@@ -41,13 +42,14 @@ def extract_view_entities_sql_query(self) -> str:
4142
v.TABLE_SCHEMA AS EntitySchema,
4243
v.COMMENT AS Definition
4344
FROM
44-
INFORMATION_SCHEMA.VIEWS v"""
45+
INFORMATION_SCHEMA.VIEWS v
46+
ORDER BY EntitySchema, Entity"""
4547

4648
def extract_columns_sql_query(self, entity: EntityItem) -> str:
4749
"""A property to extract column information from a Snowflake database."""
4850
return f"""SELECT
4951
COLUMN_NAME AS Name,
50-
DATA_TYPE AS Type,
52+
DATA_TYPE AS DataType,
5153
COMMENT AS Definition
5254
FROM
5355
INFORMATION_SCHEMA.COLUMNS

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/tsql_data_dictionary_creator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def extract_table_entities_sql_query(self) -> str:
3939
AND ep.class = 1
4040
AND ep.name = 'MS_Description'
4141
WHERE
42-
t.TABLE_TYPE = 'BASE TABLE';"""
42+
t.TABLE_TYPE = 'BASE TABLE';
43+
ORDER BY EntitySchema, Entity"""
4344

4445
@property
4546
def extract_view_entities_sql_query(self) -> str:
@@ -55,7 +56,8 @@ def extract_view_entities_sql_query(self) -> str:
5556
ON ep.major_id = OBJECT_ID(v.TABLE_SCHEMA + '.' + v.TABLE_NAME)
5657
AND ep.minor_id = 0
5758
AND ep.class = 1
58-
AND ep.name = 'MS_Description';"""
59+
AND ep.name = 'MS_Description';
60+
ORDER BY EntitySchema, Entity"""
5961

6062
def extract_columns_sql_query(self, entity: EntityItem) -> str:
6163
"""A property to extract column information from a SQL Server database."""

0 commit comments

Comments
 (0)