Skip to content

Commit bb8acd5

Browse files
committed
Update validation
1 parent 82cbe03 commit bb8acd5

File tree

1 file changed

+15
-10
lines changed
  • text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors

1 file changed

+15
-10
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,19 @@ async def query_execution_with_limit(
177177
"""
178178

179179
# Validate the SQL query
180-
validation_result = await self.query_validation(sql_query)
180+
(
181+
validation_result,
182+
cleaned_query,
183+
validation_errors,
184+
) = await self.query_validation(sql_query)
181185

182-
if isinstance(validation_result, bool) and validation_result:
183-
result = await self.query_execution(sql_query, cast_to=None, limit=25)
186+
if validation_result and validation_errors is None:
187+
result = await self.query_execution(cleaned_query, cast_to=None, limit=25)
184188

185189
return json.dumps(
186190
{
187191
"type": "query_execution_with_limit",
188-
"sql_query": sql_query,
192+
"sql_query": cleaned_query,
189193
"sql_rows": result,
190194
},
191195
default=str,
@@ -194,8 +198,8 @@ async def query_execution_with_limit(
194198
return json.dumps(
195199
{
196200
"type": "errored_query_execution_with_limit",
197-
"sql_query": sql_query,
198-
"errors": validation_result,
201+
"sql_query": cleaned_query,
202+
"errors": validation_errors,
199203
},
200204
default=str,
201205
)
@@ -209,9 +213,10 @@ async def query_validation(
209213
) -> Union[bool | list[dict]]:
210214
"""Validate the SQL query."""
211215
try:
212-
logging.info("Validating SQL Query: %s", sql_query)
216+
cleaned_query = sql_query.strip().replace("\n", " ")
217+
logging.info("Validating SQL Query: %s", cleaned_query)
213218
parsed_queries = sqlglot.parse(
214-
sql_query,
219+
cleaned_query,
215220
read=self.database_engine.value.lower(),
216221
)
217222

@@ -255,10 +260,10 @@ def handle_node(node):
255260

256261
except sqlglot.errors.ParseError as e:
257262
logging.error("SQL Query is invalid: %s", e.errors)
258-
return e.errors
263+
return False, None, e.errors
259264
else:
260265
logging.info("SQL Query is valid.")
261-
return True
266+
return True, cleaned_query, None
262267

263268
async def fetch_sql_queries_with_schemas_from_cache(
264269
self, question: str, injected_parameters: dict = None

0 commit comments

Comments
 (0)