Skip to content

Commit f6801e1

Browse files
committed
Update
1 parent 4b37982 commit f6801e1

File tree

3 files changed

+93
-86
lines changed

3 files changed

+93
-86
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,7 @@ def __init__(
195195
excluded_entities: list[str] = None,
196196
excluded_schemas: list[str] = None,
197197
single_file: bool = False,
198-
generate_definitions: bool = True,
198+
generate_definitions: bool = False,
199199
output_directory: str = None,
200200
):
201201
"""A method to initialize the DataDictionaryCreator class.
@@ -226,6 +226,9 @@ def __init__(
226226

227227
self.database_engine = None
228228

229+
self.database_semaphore = asyncio.Semaphore(50)
230+
self.llm_semaphone = asyncio.Semaphore(20)
231+
229232
if output_directory is None:
230233
self.output_directory = "."
231234

@@ -292,19 +295,20 @@ async def query_entities(
292295

293296
logging.info(f"Running query: {sql_query}")
294297
results = []
295-
async with await aioodbc.connect(dsn=connection_string) as sql_db_client:
296-
async with sql_db_client.cursor() as cursor:
297-
await cursor.execute(sql_query)
298+
async with self.database_semaphore:
299+
async with await aioodbc.connect(dsn=connection_string) as sql_db_client:
300+
async with sql_db_client.cursor() as cursor:
301+
await cursor.execute(sql_query)
298302

299-
columns = [column[0] for column in cursor.description]
303+
columns = [column[0] for column in cursor.description]
300304

301-
rows = await cursor.fetchall()
305+
rows = await cursor.fetchall()
302306

303-
for row in rows:
304-
if cast_to:
305-
results.append(cast_to.from_sql_row(row, columns))
306-
else:
307-
results.append(dict(zip(columns, row)))
307+
for row in rows:
308+
if cast_to:
309+
results.append(cast_to.from_sql_row(row, columns))
310+
else:
311+
results.append(dict(zip(columns, row)))
308312

309313
return results
310314

@@ -428,7 +432,7 @@ async def extract_entities_with_definitions(self) -> list[EntityItem]:
428432
all_entities = [
429433
entity
430434
for entity in all_entities
431-
if entity.entity.lower() not in self.excluded_entities
435+
if entity.name.lower() not in self.excluded_entities
432436
and entity.entity_schema.lower() not in self.excluded_schemas
433437
]
434438

@@ -529,11 +533,12 @@ async def generate_column_definition(self, entity: EntityItem, column: ColumnIte
529533

530534
column_definition_input += existing_definition_string
531535

532-
logging.info(f"Generating definition for {column.name}")
533-
definition = await self.send_request_to_llm(
534-
column_definition_system_prompt, column_definition_input
535-
)
536-
logging.info(f"Definition for {column.name}: {definition}")
536+
async with self.llm_semaphone:
537+
logging.info(f"Generating definition for {column.name}")
538+
definition = await self.send_request_to_llm(
539+
column_definition_system_prompt, column_definition_input
540+
)
541+
logging.info(f"Definition for {column.name}: {definition}")
537542

538543
column.definition = definition
539544

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/databricks_data_dictionary_creator.py

Lines changed: 36 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -111,41 +111,42 @@ async def query_entities(self, sql_query: str, cast_to: any = None) -> list[dict
111111
logging.info(f"Running query: {sql_query}")
112112
results = []
113113

114-
# Set up connection parameters for Databricks SQL endpoint
115-
connection = sql.connect(
116-
server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"],
117-
http_path=os.environ["Text2Sql__Databricks__HttpPath"],
118-
access_token=os.environ["Text2Sql__Databricks__AccessToken"],
119-
)
120-
121-
try:
122-
# Create a cursor
123-
cursor = connection.cursor()
124-
125-
# Execute the query in a thread-safe manner
126-
await asyncio.to_thread(cursor.execute, sql_query)
127-
128-
# Fetch column names
129-
columns = [col[0] for col in cursor.description]
130-
131-
# Fetch rows
132-
rows = await asyncio.to_thread(cursor.fetchall)
133-
134-
# Process rows
135-
for row in rows:
136-
if cast_to:
137-
results.append(cast_to.from_sql_row(row, columns))
138-
else:
139-
results.append(dict(zip(columns, row)))
140-
141-
except Exception as e:
142-
logging.error(f"Error while executing query: {e}")
143-
raise
144-
finally:
145-
cursor.close()
146-
connection.close()
147-
148-
return results
114+
async with self.database_semaphore:
115+
# Set up connection parameters for Databricks SQL endpoint
116+
connection = sql.connect(
117+
server_hostname=os.environ["Text2Sql__Databricks__ServerHostname"],
118+
http_path=os.environ["Text2Sql__Databricks__HttpPath"],
119+
access_token=os.environ["Text2Sql__Databricks__AccessToken"],
120+
)
121+
122+
try:
123+
# Create a cursor
124+
cursor = connection.cursor()
125+
126+
# Execute the query in a thread-safe manner
127+
await asyncio.to_thread(cursor.execute, sql_query)
128+
129+
# Fetch column names
130+
columns = [col[0] for col in cursor.description]
131+
132+
# Fetch rows
133+
rows = await asyncio.to_thread(cursor.fetchall)
134+
135+
# Process rows
136+
for row in rows:
137+
if cast_to:
138+
results.append(cast_to.from_sql_row(row, columns))
139+
else:
140+
results.append(dict(zip(columns, row)))
141+
142+
except Exception as e:
143+
logging.error(f"Error while executing query: {e}")
144+
raise
145+
finally:
146+
cursor.close()
147+
connection.close()
148+
149+
return results
149150

150151

151152
if __name__ == "__main__":

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/snowflake_data_dictionary_creator.py

Lines changed: 35 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -94,40 +94,41 @@ async def query_entities(
9494
logging.info(f"Running query: {sql_query}")
9595
results = []
9696

97-
# Create a connection to Snowflake, without specifying a schema
98-
conn = snowflake.connector.connect(
99-
user=os.environ["Text2Sql__Snowflake__User"],
100-
password=os.environ["Text2Sql__Snowflake__Password"],
101-
account=os.environ["Text2Sql__Snowflake__Account"],
102-
warehouse=os.environ["Text2Sql__Snowflake__Warehouse"],
103-
database=os.environ["Text2Sql__DatabaseName"],
104-
)
105-
106-
try:
107-
# Using the connection to create a cursor
108-
cursor = conn.cursor()
109-
110-
# Execute the query
111-
await asyncio.to_thread(cursor.execute, sql_query)
112-
113-
# Fetch column names
114-
columns = [col[0] for col in cursor.description]
115-
116-
# Fetch rows
117-
rows = await asyncio.to_thread(cursor.fetchall)
118-
119-
# Process rows
120-
for row in rows:
121-
if cast_to:
122-
results.append(cast_to.from_sql_row(row, columns))
123-
else:
124-
results.append(dict(zip(columns, row)))
125-
126-
finally:
127-
cursor.close()
128-
conn.close()
129-
130-
return results
97+
async with self.database_semaphore:
98+
# Create a connection to Snowflake, without specifying a schema
99+
conn = snowflake.connector.connect(
100+
user=os.environ["Text2Sql__Snowflake__User"],
101+
password=os.environ["Text2Sql__Snowflake__Password"],
102+
account=os.environ["Text2Sql__Snowflake__Account"],
103+
warehouse=os.environ["Text2Sql__Snowflake__Warehouse"],
104+
database=os.environ["Text2Sql__DatabaseName"],
105+
)
106+
107+
try:
108+
# Using the connection to create a cursor
109+
cursor = conn.cursor()
110+
111+
# Execute the query
112+
await asyncio.to_thread(cursor.execute, sql_query)
113+
114+
# Fetch column names
115+
columns = [col[0] for col in cursor.description]
116+
117+
# Fetch rows
118+
rows = await asyncio.to_thread(cursor.fetchall)
119+
120+
# Process rows
121+
for row in rows:
122+
if cast_to:
123+
results.append(cast_to.from_sql_row(row, columns))
124+
else:
125+
results.append(dict(zip(columns, row)))
126+
127+
finally:
128+
cursor.close()
129+
conn.close()
130+
131+
return results
131132

132133

133134
if __name__ == "__main__":

0 commit comments

Comments
 (0)