Skip to content

Commit 596686a

Browse files
committed
update new data loader design
1 parent 41bbf7a commit 596686a

File tree

9 files changed

+565
-36
lines changed

9 files changed

+565
-36
lines changed

py-src/data_formulator/agent_routes.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from data_formulator.agents.agent_data_load import DataLoadAgent
3030
from data_formulator.agents.agent_data_clean import DataCleanAgent
3131
from data_formulator.agents.agent_code_explanation import CodeExplanationAgent
32-
32+
from data_formulator.agents.agent_query_completion import QueryCompletionAgent
3333
from data_formulator.agents.client_utils import Client
3434

3535
from data_formulator.db_manager import db_manager
@@ -437,4 +437,25 @@ def request_code_expl():
437437
expl = code_expl_agent.run(input_tables, code)
438438
else:
439439
expl = ""
440-
return expl
440+
return expl
441+
442+
@agent_bp.route('/query-completion', methods=['POST'])
def query_completion():
    """Complete a partial or natural-language data query via QueryCompletionAgent.

    Expects a JSON body with `model`, `data_source_metadata`, and `query`;
    returns a JSON payload carrying the agent's reasoning and completed query.
    """
    if not request.is_json:
        payload = { "token": "", "status": "error", "reasoning": "unable to complete query", "query": "" }
    else:
        logger.info("# request data: ")
        content = request.get_json()

        client = get_client(content['model'])

        agent = QueryCompletionAgent(client=client)
        reasoning, completed_query = agent.run(content["data_source_metadata"], content["query"])

        payload = { "token": "", "status": "ok", "reasoning": reasoning, "query": completed_query }

    response = flask.jsonify(payload)
    # allow cross-origin callers, matching the other agent routes
    response.headers.add('Access-Control-Allow-Origin', '*')
    return response
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import pandas as pd
5+
import json
6+
7+
from data_formulator.agents.agent_utils import extract_code_from_gpt_response, extract_json_objects
8+
import re
9+
import logging
10+
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
SYSTEM_PROMPT = '''You are a data scientist to help with data queries.
16+
The user will provide you with a description of the data source and tables available in the [DATA SOURCE] section and a query in the [USER INPUTS] section.
17+
You will need to help the user complete the query and provide reasoning for the query you generated in the [OUTPUT] section.
18+
19+
Input format:
20+
* The data source description is a json object with the following fields:
21+
* `data_source`: the name of the data source
22+
* `tables`: a list of tables in the data source, which maps the table name to the list of columns available in the table.
23+
* The user input is a natural language description of the query or a partial query you need to complete.
24+
25+
Steps:
26+
* Based on data source description and user input, you should first decide on what language should be used to query the data.
27+
* Then, describe the logic for the query you generated in a json object in a block ```json``` with the following fields:
28+
* `language`: the language of the query you generated
29+
* `tables`: the names of the tables you will use in the query
30+
* `logic`: the reasoning behind why you chose the tables and the logic for the query you generated
31+
* Finally, generate the complete query in the language specified in a code block ```{language}```.
32+
33+
Output format:
34+
* The output should be in the following format, no other text should be included:
35+
36+
[REASONING]
37+
```json
38+
{
39+
"language": {language},
40+
"tables": {tables},
41+
"logic": {logic}
42+
}
43+
```
44+
45+
[QUERY]
46+
```{language}
47+
{query}
48+
```
49+
'''
50+
51+
class QueryCompletionAgent(object):
    """Agent that completes a (partial or natural-language) data query.

    Given a data-source description and a user query, it asks the model for a
    reasoning JSON object plus a completed query, then parses both out of the
    [REASONING]/[QUERY] sections of the response.
    """

    def __init__(self, client):
        self.client = client

    def run(self, data_source_metadata, query):
        """Return a (reasoning_dict, completed_query_string) pair."""
        prompt = (
            f"[DATA SOURCE]\n\n{json.dumps(data_source_metadata, indent=2)}"
            f"\n\n[USER INPUTS]\n\n{query}\n\n[REASONING]\n"
        )
        logger.info(prompt)

        ###### the part that calls open_ai
        completion = self.client.get_completion(messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ])

        # Re-attach the [REASONING] marker we seeded the model with so the
        # full response can be split on section markers below.
        content = '[REASONING]\n' + completion.choices[0].message.content
        logger.info(f"=== query completion output ===>\n{content}\n")

        reasoning_section = content.split("[REASONING]")[1].split("[QUERY]")[0].strip()
        reasoning = extract_json_objects(reasoning_section)[0]

        # Extract the query by removing the language markers
        raw_query = content.split("[QUERY]")[1].strip()
        fence = re.search(r"```(\w+)\s+(.*?)```", raw_query, re.DOTALL)
        if fence:
            raw_query = fence.group(2).strip()

        return reasoning, raw_query

py-src/data_formulator/data_loader/external_data_loader.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,41 @@
55
import duckdb
66
import random
77
import string
8+
import re
9+
10+
def sanitize_table_name(name_as: str) -> str:
    """Sanitize *name_as* so it is safe to use as a SQL table identifier.

    Strips common SQL-injection tokens, replaces non-identifier characters
    with underscores, prefixes names that start with a digit or collide with
    a SQL keyword, and truncates to 63 characters (common SQL identifier limit).

    Raises:
        ValueError: if the name is empty, or becomes empty after sanitization.
    """
    if not name_as:
        raise ValueError("Table name cannot be empty")

    # Remove any SQL injection attempts
    name_as = name_as.replace(";", "").replace("--", "").replace("/*", "").replace("*/", "")

    # Replace invalid characters with underscores
    # This includes special characters, spaces, dots, dashes, and other non-alphanumeric chars
    sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', name_as)

    # Bug fix: input consisting solely of stripped tokens (e.g. ";" or "--")
    # becomes empty here; indexing sanitized[0] below would raise IndexError.
    if not sanitized:
        raise ValueError("Table name cannot be empty after sanitization")

    # Ensure the name starts with a letter or underscore
    if not sanitized[0].isalpha() and sanitized[0] != '_':
        sanitized = '_' + sanitized

    # Ensure the name is not a SQL keyword
    sql_keywords = {
        'SELECT', 'FROM', 'WHERE', 'GROUP', 'BY', 'ORDER', 'HAVING', 'LIMIT',
        'OFFSET', 'JOIN', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'OUTER', 'ON',
        'AND', 'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'UNION', 'ALL', 'DISTINCT',
        'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'TABLE', 'VIEW', 'INDEX',
        'ALTER', 'ADD', 'COLUMN', 'PRIMARY', 'KEY', 'FOREIGN', 'REFERENCES',
        'CONSTRAINT', 'DEFAULT', 'CHECK', 'UNIQUE', 'CASCADE', 'RESTRICT'
    }

    if sanitized.upper() in sql_keywords:
        sanitized = '_' + sanitized

    # Ensure the name is not too long (common SQL limit is 63 characters)
    if len(sanitized) > 63:
        sanitized = sanitized[:63]

    return sanitized
843

944
class ExternalDataLoader(ABC):
1045

@@ -45,6 +80,10 @@ def list_tables(self) -> List[Dict[str, Any]]:
4580
def ingest_data(self, table_name: str, name_as: str = None, size: int = 1000000):
4681
pass
4782

83+
@abstractmethod
84+
def view_query_sample(self, query: str) -> str:
85+
pass
86+
4887
@abstractmethod
4988
def ingest_data_from_query(self, query: str, name_as: str):
5089
pass

py-src/data_formulator/data_loader/kusto_data_loader.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
from azure.kusto.data import KustoClient, KustoConnectionStringBuilder
99
from azure.kusto.data.helpers import dataframe_from_result_table
1010

11-
from data_formulator.data_loader.external_data_loader import ExternalDataLoader
11+
from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
1212

13-
def sanitize_table_name(table_name: str) -> str:
14-
return table_name.replace(".", "_").replace("-", "_")
1513

1614
class KustoDataLoader(ExternalDataLoader):
1715

@@ -53,8 +51,6 @@ def query(self, kql: str) -> pd.DataFrame:
5351
return dataframe_from_result_table(result.primary_results[0])
5452

5553
def list_tables(self) -> List[Dict[str, Any]]:
56-
57-
5854
# first list functions (views)
5955
query = ".show functions"
6056
function_result_df = self.query(query)
@@ -170,6 +166,8 @@ def ingest_data(self, table_name: str, name_as: str = None, size: int = 5000000)
170166

171167
total_rows_ingested += len(chunk_df)
172168

169+
    def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
        """Run *query* against Kusto and return at most 10 rows as record dicts."""
        # NOTE(review): the original annotation said `-> str`, but
        # to_dict(orient="records") returns a list of dicts; annotation fixed.
        return self.query(query).head(10).to_dict(orient="records")
173171

174172
def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame:
175173
# Sanitize the table name for SQL compatibility

py-src/data_formulator/data_loader/mysql_data_loader.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import duckdb
55

6-
from data_formulator.data_loader.external_data_loader import ExternalDataLoader
6+
from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
77
from typing import Dict, Any
88

99
class MySQLDataLoader(ExternalDataLoader):
@@ -31,6 +31,12 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti
3131
if value:
3232
attatch_string += f"{key}={value} "
3333

34+
# Detach existing mysqldb connection if it exists
35+
try:
36+
self.duck_db_conn.execute("DETACH mysqldb;")
37+
except:
38+
pass # Ignore if mysqldb doesn't exist
39+
3440
# Register MySQL connection
3541
self.duck_db_conn.execute(f"ATTACH '{attatch_string}' AS mysqldb (TYPE mysql);")
3642

@@ -44,21 +50,21 @@ def list_tables(self):
4450

4551
for schema, table_name in tables_df.values:
4652

47-
full_table_name = f"{schema}.{table_name}"
53+
full_table_name = f"mysqldb.{schema}.{table_name}"
4854

4955
# Get column information using DuckDB's information schema
50-
columns_df = self.duck_db_conn.execute(f"DESCRIBE mysqldb.{full_table_name}").df()
56+
columns_df = self.duck_db_conn.execute(f"DESCRIBE {full_table_name}").df()
5157
columns = [{
5258
'name': row['column_name'],
5359
'type': row['column_type']
5460
} for _, row in columns_df.iterrows()]
5561

5662
# Get sample data
57-
sample_df = self.duck_db_conn.execute(f"SELECT * FROM mysqldb.{full_table_name} LIMIT 10").df()
63+
sample_df = self.duck_db_conn.execute(f"SELECT * FROM {full_table_name} LIMIT 10").df()
5864
sample_rows = json.loads(sample_df.to_json(orient="records"))
5965

6066
# get row count
61-
row_count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM mysqldb.{full_table_name}").fetchone()[0]
67+
row_count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM {full_table_name}").fetchone()[0]
6268

6369
table_metadata = {
6470
"row_count": row_count,
@@ -73,19 +79,24 @@ def list_tables(self):
7379

7480
return results
7581

76-
def ingest_data(self, table_name: str, name_as: str = None, size: int = 1000000):
82+
    def ingest_data(self, table_name: str, name_as: str | None = None, size: int = 1000000):
        """Copy up to *size* rows of a MySQL table into the local DuckDB database.

        Args:
            table_name: fully qualified source name (e.g. "mysqldb.schema.table",
                as produced by list_tables).
            name_as: target table name in DuckDB's main schema; defaults to the
                last dotted component of table_name.
            size: maximum number of rows to copy.
        """
        # Create table in the main DuckDB database from MySQL data
        if name_as is None:
            name_as = table_name.split('.')[-1]

        # target name is sanitized; NOTE(review): table_name itself is
        # interpolated unsanitized — assumed to come from list_tables, confirm.
        name_as = sanitize_table_name(name_as)

        self.duck_db_conn.execute(f"""
            CREATE OR REPLACE TABLE main.{name_as} AS
            SELECT * FROM {table_name}
            LIMIT {size}
        """)
""")
8694

95+
    def view_query_sample(self, query: str) -> list[dict[str, Any]]:
        """Execute *query* in DuckDB and return at most 10 rows as record dicts."""
        # NOTE(review): the original annotation said `-> str`, but
        # to_dict(orient="records") returns a list of dicts; annotation fixed.
        return self.duck_db_conn.execute(query).df().head(10).to_dict(orient="records")
97+
8798
    def ingest_data_from_query(self, query: str, name_as: str) -> None:
        """Materialize the result of *query* as table *name_as* in DuckDB.

        NOTE(review): the original annotation said `-> pd.DataFrame`, but the
        method returns nothing; annotation fixed to None.
        """
        # Execute the query and get results as a DataFrame
        df = self.duck_db_conn.execute(query).df()
        # Use the base class's method to ingest the DataFrame
        self.ingest_df_to_duckdb(df, name_as)

py-src/data_formulator/tables_routes.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -691,17 +691,20 @@ def sanitize_db_error_message(error: Exception) -> Tuple[str, int]:
691691
# Define patterns for known safe errors
692692
safe_error_patterns = {
693693
# Database table errors
694-
r"Table.*does not exist": ("Specified table was not found", 404),
695-
r"Table.*already exists": ("A table with this name already exists", 409),
694+
r"Table.*does not exist": (error_msg, 404),
695+
r"Table.*already exists": (error_msg, 409),
696696
# Query errors
697-
r"syntax error in SQL": ("Invalid SQL query syntax", 400),
698-
r"Invalid input syntax": ("Invalid input data format", 400),
697+
r"syntax error": (error_msg, 400),
698+
r"Catalog Error": (error_msg, 404),
699+
r"Binder Error": (error_msg, 400),
700+
r"Invalid input syntax": (error_msg, 400),
701+
699702
# File errors
700-
r"No such file": ("File not found", 404),
703+
r"No such file": (error_msg, 404),
701704
r"Permission denied": ("Access denied", 403),
702705

703706
# Data loader errors
704-
r"Entity ID": ("Entity ID not found, please check the data loader parameters", 500),
707+
r"Entity ID": (error_msg, 500),
705708
r"session_id": ("session_id not found, please refresh the page", 500),
706709
}
707710

@@ -790,6 +793,70 @@ def data_loader_ingest_data():
790793
"message": "Successfully ingested data from data loader"
791794
})
792795

796+
except Exception as e:
797+
logger.error(f"Error ingesting data from data loader: {str(e)}")
798+
safe_msg, status_code = sanitize_db_error_message(e)
799+
return jsonify({
800+
"status": "error",
801+
"message": safe_msg
802+
}), status_code
803+
804+
805+
@tables_bp.route('/data-loader/view-query-sample', methods=['POST'])
def data_loader_view_query_sample():
    """View a sample of data from a query"""
    try:
        payload = request.get_json()
        loader_type = payload.get('data_loader_type')
        loader_params = payload.get('data_loader_params')
        query = payload.get('query')

        if loader_type not in DATA_LOADERS:
            message = f"Invalid data loader type. Must be one of: {', '.join(DATA_LOADERS.keys())}"
            return jsonify({"status": "error", "message": message}), 400

        # run the sample query against the session-scoped DuckDB connection
        with db_manager.connection(session['session_id']) as duck_db_conn:
            loader = DATA_LOADERS[loader_type](loader_params, duck_db_conn)
            rows = loader.view_query_sample(query)

        return jsonify({
            "status": "success",
            "sample": rows,
            "message": "Successfully retrieved query sample"
        })
    except Exception as e:
        # map raw DB errors to a safe client-facing message + HTTP status
        logger.error(f"Error viewing query sample: {str(e)}")
        safe_msg, status_code = sanitize_db_error_message(e)
        return jsonify({
            "status": "error",
            "sample": [],
            "message": safe_msg
        }), status_code
836+
837+
@tables_bp.route('/data-loader/ingest-data-from-query', methods=['POST'])
838+
def data_loader_ingest_data_from_query():
839+
"""Ingest data from a data loader"""
840+
841+
try:
842+
data = request.get_json()
843+
data_loader_type = data.get('data_loader_type')
844+
data_loader_params = data.get('data_loader_params')
845+
query = data.get('query')
846+
name_as = data.get('name_as')
847+
848+
if data_loader_type not in DATA_LOADERS:
849+
return jsonify({"status": "error", "message": f"Invalid data loader type. Must be one of: {', '.join(DATA_LOADERS.keys())}"}), 400
850+
851+
with db_manager.connection(session['session_id']) as duck_db_conn:
852+
data_loader = DATA_LOADERS[data_loader_type](data_loader_params, duck_db_conn)
853+
data_loader.ingest_data_from_query(query, name_as)
854+
855+
return jsonify({
856+
"status": "success",
857+
"message": "Successfully ingested data from data loader"
858+
})
859+
793860
except Exception as e:
794861
logger.error(f"Error ingesting data from data loader: {str(e)}")
795862
safe_msg, status_code = sanitize_db_error_message(e)

src/app/utils.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ export function getUrls() {
6969
DATA_LOADER_LIST_DATA_LOADERS: `/api/tables/data-loader/list-data-loaders`,
7070
DATA_LOADER_LIST_TABLES: `/api/tables/data-loader/list-tables`,
7171
DATA_LOADER_INGEST_DATA: `/api/tables/data-loader/ingest-data`,
72+
DATA_LOADER_VIEW_QUERY_SAMPLE: `/api/tables/data-loader/view-query-sample`,
73+
DATA_LOADER_INGEST_DATA_FROM_QUERY: `/api/tables/data-loader/ingest-data-from-query`,
74+
75+
QUERY_COMPLETION: `/api/agent/query-completion`,
7276
};
7377
}
7478

0 commit comments

Comments
 (0)