
Commit 2731397
committed: ready for version 0.2
1 parent 44b6373 commit 2731397

11 files changed: +227 -124 lines

.env.template

Lines changed: 2 additions & 1 deletion
````diff
@@ -6,7 +6,8 @@ DISABLE_DISPLAY_KEYS=false # if true, the display keys will not be shown in the
 EXEC_PYTHON_IN_SUBPROCESS=false # if true, the python code will be executed in a subprocess to avoid crashing the main app, but it will increase the time of response
 
 # External atabase connection settings
-# check https://duckdb.org/docs/stable/extensions/mysql.html and https://duckdb.org/docs/stable/extensions/postgres.html
+# check https://duckdb.org/docs/stable/extensions/mysql.html
+# and https://duckdb.org/docs/stable/extensions/postgres.html
 USE_EXTERNAL_DB=false # if true, the app will use an external database instead of the one in the app
 DB_NAME=mysql_db # the name to refer to this database connection
 DB_TYPE=mysql # mysql or postgresql
````
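These variables feed DuckDB's MySQL and Postgres extensions (the docs linked in the comment above). As a rough sketch of what the attach flow looks like on the DuckDB side, with placeholder host/user/database values rather than anything this commit configures:

```python
import duckdb

# Sketch of an external-database attach in DuckDB; host, user, port and
# database names here are placeholders, not values from this commit.
con = duckdb.connect()
con.execute("INSTALL mysql")
con.execute("LOAD mysql")
con.execute(
    "ATTACH 'host=localhost user=root port=3306 database=mydb' "
    "AS mysql_db (TYPE mysql)"
)
print(con.execute("SHOW DATABASES").fetchall())
```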

py-src/data_formulator/agents/agent_sql_data_rec.py

Lines changed: 11 additions & 10 deletions
````diff
@@ -4,7 +4,7 @@
 import json
 
 from data_formulator.agents.agent_utils import extract_json_objects, extract_code_from_gpt_response
-from data_formulator.agents.agent_sql_data_transform import get_sql_table_statistics_str
+from data_formulator.agents.agent_sql_data_transform import get_sql_table_statistics_str, sanitize_table_name
 
 import random
 import string
````
````diff
@@ -64,6 +64,10 @@
 3. The [OUTPUT] must only contain two items:
     - a json object (wrapped in ```json```) representing the refined goal (including "mode", "recommendation", "output_fields", "chart_type", "visualization_fields")
     - a sql query block (wrapped in ```sql```) representing the transformation code, do not add any extra text explanation.
+
+some notes:
+    - in DuckDB, you escape a single quote within a string by doubling it ('') rather than using a backslash (\').
+    - in DuckDB, you need to use proper date functions to perform date operations.
 '''
 
 example = """
````
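Both new prompt notes match standard DuckDB behavior, which can be checked directly (the literals below are made up for illustration):

```python
import duckdb

con = duckdb.connect()

# Single quotes inside DuckDB string literals are escaped by doubling them.
print(con.execute("SELECT 'it''s a test' AS s").fetchone())  # ("it's a test",)

# Date arithmetic goes through date functions/casts, not string manipulation.
print(con.execute("SELECT date_part('year', DATE '2024-03-01')").fetchone())  # (2024,)
```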
````diff
@@ -167,21 +171,17 @@ def process_gpt_response(self, input_tables, messages, response):
                 row_count = self.conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
 
                 # Only limit to 5000 if there are more rows
-                if row_count > 5000:
-                    query_output = self.conn.execute(f"SELECT * FROM {table_name} LIMIT 5000").fetch_df()
-                else:
-                    query_output = self.conn.execute(f"SELECT * FROM {table_name}").fetch_df()
-                    self.conn.execute(f"DROP VIEW {table_name}")
+                query_output = self.conn.execute(f"SELECT * FROM {table_name} LIMIT 5000").fetch_df()
 
                 result = {
                     "status": "ok",
                     "code": code_str,
                     "content": {
-                        'rows': query_output.to_dict('records'),
+                        'rows': json.loads(query_output.to_json(orient='records')),
                         'virtual': {
                             'table_name': table_name,
                             'row_count': row_count
-                        } if row_count > 5000 else None
+                        }
                     },
                 }
             except Exception as e:
````
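The `to_dict('records')` to `json.loads(df.to_json(...))` switch in this hunk is about JSON safety: pandas' `to_json` renders NaN as null and timestamps as epoch milliseconds, while raw `to_dict` records can contain objects the `json` module refuses. A minimal sketch of the difference (invented dataframe):

```python
import json
import pandas as pd

df = pd.DataFrame({
    "x": [1.0, float("nan")],
    "t": pd.to_datetime(["2024-01-01", "2024-01-02"]),
})

# to_dict('records') keeps Timestamp objects, which json.dumps rejects,
# and it would emit non-standard NaN for missing floats.
try:
    json.dumps(df.to_dict("records"))
except TypeError as e:
    print("to_dict path fails:", e)

# The to_json round trip produces plain JSON types (null, numbers).
rows = json.loads(df.to_json(orient="records"))
print(rows)  # [{'x': 1.0, 't': 1704067200000}, {'x': None, 't': 1704153600000}]
```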
````diff
@@ -211,8 +211,9 @@ def process_gpt_response(self, input_tables, messages, response):
     def run(self, input_tables, description, n=1):
         data_summary = ""
         for table in input_tables:
-            table_summary_str = get_sql_table_statistics_str(self.conn, table['name'])
-            data_summary += f"[TABLE {table['name']}]\n\n{table_summary_str}\n\n"
+            table_name = sanitize_table_name(table['name'])
+            table_summary_str = get_sql_table_statistics_str(self.conn, table_name)
+            data_summary += f"[TABLE {table_name}]\n\n{table_summary_str}\n\n"
 
         user_query = f"[CONTEXT]\n\n{data_summary}\n\n[GOAL]\n\n{description}\n\n[OUTPUT]\n"
 
````

py-src/data_formulator/agents/agent_sql_data_transform.py

Lines changed: 44 additions & 16 deletions
````diff
@@ -9,13 +9,13 @@
 import pandas as pd
 
 import logging
-
+import re
 # Replace/update the logger configuration
 logger = logging.getLogger(__name__)
 
 SYSTEM_PROMPT = '''You are a data scientist to help user to transform data that will be used for visualization.
 The user will provide you information about what data would be needed, and your job is to create a sql query based on the input data summary, transformation instruction and expected fields.
-The users' instruction includes "expected fields" that the user want for visualization, and natural language instructions "goal" that describe what data is needed.
+The users' instruction includes "visualization_fields" that the user want for visualization, and natural language instructions "goal" that describe what data is needed.
 
 **Important:**
 - NEVER make assumptions or judgments about a person's gender, biological sex, sexuality, religion, race, nationality, ethnicity, political stance, socioeconomic status, mental health, invisible disabilities, medical conditions, personality type, social impressions, emotional state, and cognitive state.
````
````diff
@@ -24,15 +24,22 @@
 
 Concretely, you should first refine users' goal and then create a sql query in the [OUTPUT] section based off the [CONTEXT] and [GOAL]:
 
-1. First, refine users' [GOAL]. The main objective in this step is to check if "visualization_fields" provided by the user are sufficient to achieve their "goal". Concretely:
-    (1) based on the user's "goal", elaborate the goal into a "detailed_instruction".
+1. First, refine users' [GOAL]. The main objective in this step is to decide data transformation based on the user's goal.
+Concretely:
+    (1) based on the user's "goal" and provided "visualization_fields", elaborate the goal into a "detailed_instruction".
+        - first elaborate which fields the user wants to visualize based on "visualization_fields";
+        - then, elaborate the goal into a "detailed_instruction" contextualized with the provided "visualization_fields".
+            * note: try to distinguish whether the user wants to fitler the data with some conditions, or they want to aggregate data based on some fields.
+            * e.g., filter data to show all items from top 20 categories based on their average values, is different from showing the top 20 categories with their average values
     (2) determine "output_fields", the desired fields that the output data should have to achieve the user's goal, it's a good idea to include intermediate fields here.
-    (2) now, determine whether the user has provided sufficient fields in "visualization_fields" that are needed to achieve their goal:
-        - if the user's "visualization_fields" are sufficient, simply copy it.
+        - note: when the user asks for filtering the data, include all fields that are needed to filter the data in "output_fields" (as well as other fields the user asked for or necessary in computation).
+    (3) now, determine whether the user has provided sufficient fields in "visualization_fields" that are needed to achieve their goal:
+        - if the user's "visualization_fields" are sufficient, simply copy it from user input.
         - if the user didn't provide sufficient fields in "visualization_fields", add missing fields in "visualization_fields" (ordered them based on whether the field will be used in x,y axes or legends);
         - "visualization_fields" should only include fields that will be visualized (do not include other intermediate fields from "output_fields")
         - when adding new fields to "visualization_fields", be efficient and add only a minimal number of fields that are needed to achive the user's goal. generally, the total number of fields in "visualization_fields" should be no more than 3 for x,y,legend.
-
+        - if the user's goal is to filter the data, include all fields that are needed to filter the data in "output_fields" (as well as other fields the user asked for or necessary in computation).
+        - all existing fields user provided in "visualization_fields" should be included in "visualization_fields" list.
 Prepare the result in the following json format:
 
 ```
````
````diff
@@ -52,6 +59,10 @@
 3. The [OUTPUT] must only contain two items:
     - a json object (wrapped in ```json```) representing the refined goal (including "detailed_instruction", "output_fields", "visualization_fields" and "reason")
     - a sql query block (wrapped in ```sql```) representing the transformation code, do not add any extra text explanation.
+
+some notes:
+    - in DuckDB, you escape a single quote within a string by doubling it ('') rather than using a backslash (\').
+    - in DuckDB, you need to use proper date functions to perform date operations.
 '''
 
 EXAMPLE='''
````
````diff
@@ -104,6 +115,15 @@
 ```
 '''
 
+def sanitize_table_name(table_name: str) -> str:
+    """Sanitize table name to be used in SQL queries"""
+    # Replace spaces with underscores
+    sanitized_name = table_name.replace(" ", "_")
+    sanitized_name = sanitized_name.replace("-", "_")
+    # Allow alphanumeric, underscore, dot, dash, and dollar sign
+    sanitized_name = re.sub(r'[^a-zA-Z0-9_\.$]', '', sanitized_name)
+    return sanitized_name
+
 class SQLDataTransformationAgent(object):
 
     def __init__(self, client, conn, system_prompt=None):
````
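For reference, here is what the new helper does to a couple of inputs (a standalone copy of the committed logic, with invented names). Note that dashes are already replaced with underscores before the regex runs, so the character class never actually needs to keep them:

```python
import re

def sanitize_table_name(table_name: str) -> str:
    # Standalone copy of the committed helper, for illustration.
    sanitized_name = table_name.replace(" ", "_").replace("-", "_")
    return re.sub(r'[^a-zA-Z0-9_\.$]', '', sanitized_name)

print(sanitize_table_name("sales report-2024"))   # sales_report_2024
print(sanitize_table_name("weird (name)!.csv"))   # weird_name.csv
```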
````diff
@@ -156,17 +176,16 @@ def process_gpt_sql_response(self, response, messages):
                     query_output = self.conn.execute(f"SELECT * FROM {table_name} LIMIT 5000").fetch_df()
                 else:
                     query_output = self.conn.execute(f"SELECT * FROM {table_name}").fetch_df()
-                    self.conn.execute(f"DROP VIEW {table_name}")
 
                 result = {
                     "status": "ok",
                     "code": query_str,
                     "content": {
-                        'rows': query_output.to_dict('records'),
+                        'rows': json.loads(query_output.to_json(orient='records')),
                         'virtual': {
                             'table_name': table_name,
                             'row_count': row_count
-                        } if row_count > 5000 else None
+                        }
                     },
                 }
 
````
````diff
@@ -205,19 +224,24 @@ def run(self, input_tables, description, expected_fields: list[str], prev_messag
         """
 
         for table in input_tables:
-            table_name = table['name']
+            table_name = sanitize_table_name(table['name'])
+
             # Check if table exists in the connection
             try:
                 self.conn.execute(f"DESCRIBE {table_name}")
             except Exception:
                 # Table doesn't exist, create it from the dataframe
                 df = pd.DataFrame(table['rows'])
+
                 # Register the dataframe as a temporary view
-                self.conn.register(f'df_temp_{table_name}', df)
+                self.conn.register(f'df_temp', df)
                 # Create a permanent table from the temporary view
-                self.conn.execute(f"CREATE VIEW {table_name} AS SELECT * FROM df_temp_{table_name}")
+                self.conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM df_temp")
                 # Drop the temporary view
-                self.conn.execute(f"DROP VIEW df_temp_{table_name}")
+                self.conn.execute(f"DROP VIEW df_temp")
+
+                r = self.conn.execute(f"SELECT * FROM {table_name} LIMIT 10").fetch_df()
+                print(r)
                 # Log the creation of the table
                 logger.info(f"Created table {table_name} from dataframe")
 
````
````diff
@@ -232,8 +256,9 @@ def run(self, input_tables, description, expected_fields: list[str], prev_messag
 
         data_summary = ""
         for table in input_tables:
-            table_summary_str = get_sql_table_statistics_str(self.conn, table['name'])
-            data_summary += f"[TABLE {table['name']}]\n\n{table_summary_str}\n\n"
+            table_name = sanitize_table_name(table['name'])
+            table_summary_str = get_sql_table_statistics_str(self.conn, table_name)
+            data_summary += f"[TABLE {table_name}]\n\n{table_summary_str}\n\n"
 
         goal = {
             "instruction": description,
````
````diff
@@ -276,6 +301,9 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
 
 
 def get_sql_table_statistics_str(conn, table_name: str) -> str:
+    """Get a string representation of the table statistics"""
+
+    table_name = sanitize_table_name(table_name)
 
     # Get column information
     columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
````

py-src/data_formulator/db_manager.py

Lines changed: 4 additions & 5 deletions
````diff
@@ -12,12 +12,11 @@ class DuckDBManager:
     def __init__(self, external_db_connections: Dict[str, Dict[str, Any]]):
         # Store session db file paths
         self._db_files: Dict[str, str] = {}
-        # Track which extensions have been installed for which db files
-        self._installed_extensions: Dict[str, List[str]] = {}
 
-        # external db connections
+        # External db connections and tracking of installed extensions
         self._external_db_connections: Dict[str, Dict[str, Any]] = external_db_connections
-
+        self._installed_extensions: Dict[str, List[str]] = {}
+
     @contextmanager
     def connection(self, session_id: str) -> ContextManager[duckdb.DuckDBPyConnection]:
         """Get a DuckDB connection as a context manager that will be closed when exiting the context"""
````
````diff
@@ -34,7 +33,7 @@ def get_connection(self, session_id: str) -> duckdb.DuckDBPyConnection:
         """Internal method to get or create a DuckDB connection for a session"""
         # Get or create the db file path for this session
         if session_id not in self._db_files or self._db_files[session_id] is None:
-            db_file = os.path.join(tempfile.gettempdir(), f"df_{session_id}.db")
+            db_file = os.path.join(tempfile.gettempdir(), f"df_{session_id}.duckdb")
             print(f"=== Creating new db file: {db_file}")
             self._db_files[session_id] = db_file
             # Initialize extension tracking for this file
````

py-src/data_formulator/tables_routes.py

Lines changed: 51 additions & 23 deletions
````diff
@@ -41,35 +41,52 @@ def list_tables():
     try:
         result = []
         with db_manager.connection(session['session_id']) as db:
-            table_metadata_list = db.execute("SELECT database_name, schema_name, table_name, schema_name==current_schema() as is_current_schema FROM duckdb_tables() WHERE internal=False").fetchall()
+            table_metadata_list = db.execute("""
+                SELECT database_name, schema_name, table_name, schema_name==current_schema() as is_current_schema, 'table' as object_type
+                FROM duckdb_tables()
+                WHERE internal=False
+                UNION ALL
+                SELECT database_name, schema_name, view_name as table_name, schema_name==current_schema() as is_current_schema, 'view' as object_type
+                FROM duckdb_views()
+                WHERE view_name NOT LIKE 'duckdb_%' AND view_name NOT LIKE 'sqlite_%' AND view_name NOT LIKE 'pragma_%'
+            """).fetchall()
 
-            print(f"table_metadata_list: {table_metadata_list}")
+
             for table_metadata in table_metadata_list:
-                [database_name, schema_name, table_name, is_current_schema] = table_metadata
-
+                [database_name, schema_name, table_name, is_current_schema, object_type] = table_metadata
                 table_name = table_name if is_current_schema else '.'.join([database_name, schema_name, table_name])
-                # Get column information
-                columns = db.execute(f"DESCRIBE {table_name}").fetchall()
-                # Get row count
-                row_count = db.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
-                sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchdf()
+                if database_name in ['system', 'temp']:
+                    continue
+
 
-                # Check if this is a view or a table
+                print(f"table_metadata: {table_metadata}")
+
                 try:
-                    # Get both view existence and source in one query
-                    view_info = db.execute(f"SELECT view_name, sql FROM duckdb_views() WHERE view_name = '{table_name}'").fetchone()
-                    view_source = view_info[1] if view_info else None
+                    # Get column information
+                    columns = db.execute(f"DESCRIBE {table_name}").fetchall()
+                    # Get row count
+                    row_count = db.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
+                    sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchdf()
+
+                    # Check if this is a view or a table
+                    try:
+                        # Get both view existence and source in one query
+                        view_info = db.execute(f"SELECT view_name, sql FROM duckdb_views() WHERE view_name = '{table_name}'").fetchone()
+                        view_source = view_info[1] if view_info else None
+                    except Exception as e:
+                        # If the query fails, assume it's a regular table
+                        view_source = None
+
+                    result.append({
+                        "name": table_name,
+                        "columns": [{"name": col[0], "type": col[1]} for col in columns],
+                        "row_count": row_count,
+                        "sample_rows": json.loads(sample_rows.to_json(orient='records')),
+                        "view_source": view_source
+                    })
                 except Exception as e:
-                    # If the query fails, assume it's a regular table
-                    view_source = None
-
-                result.append({
-                    "name": table_name,
-                    "columns": [{"name": col[0], "type": col[1]} for col in columns],
-                    "row_count": row_count,
-                    "sample_rows": json.loads(sample_rows.to_json(orient='records')),
-                    "view_source": view_source
-                })
+                    logger.error(f"Error getting table metadata for {table_name}: {str(e)}")
+                    continue
 
         return jsonify({
             "status": "success",
````
157174
with db_manager.connection(session['session_id']) as db:
158175
# Get valid column names
159176
columns = [col[0] for col in db.execute(f"DESCRIBE {table_id}").fetchall()]
177+
178+
print(f"columns: {columns}")
160179

161180
# Filter order_by_fields to only include valid column names
162181
valid_order_by_fields = [field for field in order_by_fields if field in columns]
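The surrounding context shows the route's allow-list pattern: only names that actually came back from `DESCRIBE` survive the filter, so user-supplied field names are never interpolated into SQL unchecked. A tiny sketch with invented values:

```python
# Only names present in the DESCRIBE output pass through the filter.
columns = ["date", "sales", "region"]            # from DESCRIBE (invented)
order_by_fields = ["sales", "1; DROP TABLE t"]   # user-supplied (invented)
valid_order_by_fields = [f for f in order_by_fields if f in columns]
print(valid_order_by_fields)  # ['sales']
```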
````diff
@@ -168,11 +187,16 @@
 
             query, output_column_names = assemble_query(valid_aggregate_fields_and_functions, valid_select_fields, columns, table_id)
 
+            print(f"query: {query}")
+            print(f"output_column_names: {output_column_names}")
+
             # Modify the original query to include the count:
             count_query = f"SELECT *, COUNT(*) OVER () as total_count FROM ({query}) as subq LIMIT 1"
             result = db.execute(count_query).fetchone()
             total_row_count = result[-1] if result else 0
 
+            print(f"total_row_count: {total_row_count}")
+
             # Add ordering and limit to the main query
             if method == 'random':
                 query += f" ORDER BY RANDOM() LIMIT {sample_size}"
````
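The `COUNT(*) OVER ()` trick in the context above gets the total result size from a single probe row instead of a second full scan of the query. A self-contained sketch (table invented):

```python
import duckdb

con = duckdb.connect()
con.execute("CREATE TABLE items AS SELECT * FROM range(42) t(i)")  # invented

query = "SELECT i FROM items WHERE i % 2 = 0"
# COUNT(*) OVER () attaches the full result size to every row; LIMIT 1
# keeps the probe to a single transferred row.
count_query = f"SELECT *, COUNT(*) OVER () as total_count FROM ({query}) as subq LIMIT 1"
row = con.execute(count_query).fetchone()
print(row[-1])  # 21
```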
````diff
@@ -191,8 +215,12 @@
             else:
                 query += f" ORDER BY ROWID DESC LIMIT {sample_size}"
 
+            print(f"query: {query}")
+
             result = db.execute(query).fetchdf()
 
+            print(f"result: {result}")
+
             return jsonify({
                 "status": "success",
                 "rows": json.loads(result.to_json(orient='records')),
````
