Skip to content

Commit 757726f

Browse files
committed
various fixes and planning for SSE support
1 parent a830136 commit 757726f

File tree

9 files changed

+225
-208
lines changed

9 files changed

+225
-208
lines changed

py-src/data_formulator/db_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
14
import duckdb
25
import pandas as pd
36
from typing import Dict

py-src/data_formulator/tables_routes.py

Lines changed: 113 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import random
1515
import string
1616
from pathlib import Path
17+
import uuid
1718

1819
from data_formulator.db_manager import db_manager
1920
from data_formulator.data_loader import DATA_LOADERS
@@ -36,62 +37,72 @@
3637

3738
tables_bp = Blueprint('tables', __name__, url_prefix='/api/tables')
3839

40+
41+
def list_tables_util(db_conn):
    """
    List all user tables and views in the current database of ``db_conn``.

    Args:
        db_conn: An open DuckDB connection (anything exposing
            ``execute(sql[, params])`` whose result supports ``fetchall``,
            ``fetchone`` and ``fetchdf``).

    Returns:
        A list with one dict per table/view, each containing:
            - "name": bare name for objects in the current schema, otherwise
              the dotted ``database.schema.name`` form
            - "columns": ``[{"name": ..., "type": ...}, ...]`` from DESCRIBE
            - "row_count": result of ``COUNT(*)``
            - "sample_rows": up to 1000 rows as JSON-compatible records
            - "view_source": the view's SQL definition, or None for tables

        Objects whose metadata cannot be read are logged and skipped rather
        than failing the whole listing.
    """

    def quote_identifier(identifier):
        # Double any embedded quote so arbitrary names (spaces, punctuation,
        # reserved words) are treated as a single identifier by DuckDB.
        return '"' + str(identifier).replace('"', '""') + '"'

    results = []

    table_metadata_list = db_conn.execute("""
        SELECT database_name, schema_name, table_name, schema_name==current_schema() as is_current_schema, 'table' as object_type
        FROM duckdb_tables()
        WHERE internal=False AND database_name == current_database()
        UNION ALL
        SELECT database_name, schema_name, view_name as table_name, schema_name==current_schema() as is_current_schema, 'view' as object_type
        FROM duckdb_views()
        WHERE view_name NOT LIKE 'duckdb_%' AND view_name NOT LIKE 'sqlite_%' AND view_name NOT LIKE 'pragma_%' AND database_name == current_database()
    """).fetchall()

    for database_name, schema_name, object_name, is_current_schema, object_type in table_metadata_list:
        if database_name in ('system', 'temp'):
            continue

        # Name reported to the caller: bare when it resolves in the current
        # schema, fully qualified otherwise.
        name_parts = [object_name] if is_current_schema else [database_name, schema_name, object_name]
        table_name = '.'.join(name_parts)
        # Name used inside SQL: same parts, but each one quoted.
        sql_name = '.'.join(quote_identifier(part) for part in name_parts)

        try:
            # Get column information
            columns = db_conn.execute(f"DESCRIBE {sql_name}").fetchall()

            # Get row count; skip the sample query entirely for empty objects
            row_count = db_conn.execute(f"SELECT COUNT(*) FROM {sql_name}").fetchone()[0]
            if row_count > 0:
                sample_rows = db_conn.execute(f"SELECT * FROM {sql_name} LIMIT 1000").fetchdf()
            else:
                sample_rows = pd.DataFrame()

            # Only views have a source definition. Look it up by its real
            # (database, schema, name) triple with bound parameters; matching
            # on the dotted display name would never find views that live
            # outside the current schema.
            view_source = None
            if object_type == 'view':
                try:
                    view_info = db_conn.execute(
                        "SELECT view_name, sql FROM duckdb_views() "
                        "WHERE database_name = ? AND schema_name = ? AND view_name = ?",
                        [database_name, schema_name, object_name],
                    ).fetchone()
                    view_source = view_info[1] if view_info else None
                except Exception:
                    # Best effort: if the lookup fails, report no source
                    # instead of dropping the whole entry.
                    view_source = None

            results.append({
                "name": table_name,
                "columns": [{"name": col[0], "type": col[1]} for col in columns],
                "row_count": row_count,
                "sample_rows": json.loads(sample_rows.to_json(orient='records')),
                "view_source": view_source
            })
        except Exception as e:
            logger.error(f"Error getting table metadata for {table_name}: {str(e)}")
            continue

    return results
94+
3995
@tables_bp.route('/list-tables', methods=['GET'])
4096
def list_tables():
4197
"""List all tables in the current session"""
4298
try:
43-
result = []
4499
with db_manager.connection(session['session_id']) as db:
45-
table_metadata_list = db.execute("""
46-
SELECT database_name, schema_name, table_name, schema_name==current_schema() as is_current_schema, 'table' as object_type
47-
FROM duckdb_tables()
48-
WHERE internal=False AND database_name == current_database()
49-
UNION ALL
50-
SELECT database_name, schema_name, view_name as table_name, schema_name==current_schema() as is_current_schema, 'view' as object_type
51-
FROM duckdb_views()
52-
WHERE view_name NOT LIKE 'duckdb_%' AND view_name NOT LIKE 'sqlite_%' AND view_name NOT LIKE 'pragma_%' AND database_name == current_database()
53-
""").fetchall()
100+
results = list_tables_util(db)
54101

55-
56-
for table_metadata in table_metadata_list:
57-
[database_name, schema_name, table_name, is_current_schema, object_type] = table_metadata
58-
table_name = table_name if is_current_schema else '.'.join([database_name, schema_name, table_name])
59-
if database_name in ['system', 'temp']:
60-
continue
61-
62-
print(f"table_metadata: {table_metadata}")
63-
64-
try:
65-
# Get column information
66-
columns = db.execute(f"DESCRIBE {table_name}").fetchall()
67-
# Get row count
68-
row_count = db.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
69-
sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchdf()
70-
71-
# Check if this is a view or a table
72-
try:
73-
# Get both view existence and source in one query
74-
view_info = db.execute(f"SELECT view_name, sql FROM duckdb_views() WHERE view_name = '{table_name}'").fetchone()
75-
view_source = view_info[1] if view_info else None
76-
except Exception as e:
77-
# If the query fails, assume it's a regular table
78-
view_source = None
79-
80-
result.append({
81-
"name": table_name,
82-
"columns": [{"name": col[0], "type": col[1]} for col in columns],
83-
"row_count": row_count,
84-
"sample_rows": json.loads(sample_rows.to_json(orient='records')),
85-
"view_source": view_source
86-
})
87-
except Exception as e:
88-
logger.error(f"Error getting table metadata for {table_name}: {str(e)}")
89-
continue
90-
91-
return jsonify({
92-
"status": "success",
93-
"tables": result
94-
})
102+
return jsonify({
103+
"status": "success",
104+
"tables": results
105+
})
95106
except Exception as e:
96107
logger.error(f"Error listing tables: {str(e)}")
97108
safe_msg, status_code = sanitize_db_error_message(e)
@@ -126,7 +137,7 @@ def assemble_query(aggregate_fields_and_functions, group_fields, columns, table_
126137
elif field in columns:
127138
if function.lower() == 'count':
128139
alias = f'_count'
129-
select_parts.append(f'COUNT(*) as {alias}')
140+
select_parts.append(f'COUNT(*) as "{alias}"')
130141
output_column_names.append(alias)
131142
else:
132143
# Sanitize function name and create alias
@@ -136,7 +147,7 @@ def assemble_query(aggregate_fields_and_functions, group_fields, columns, table_
136147
aggregate_function = function.upper()
137148

138149
alias = f'{field}_{function}'
139-
select_parts.append(f'{aggregate_function}("{field}") as {alias}')
150+
select_parts.append(f'{aggregate_function}("{field}") as "{alias}"')
140151
output_column_names.append(alias)
141152

142153
# Handle group fields
@@ -288,36 +299,36 @@ def get_table_data():
288299
def create_table():
289300
"""Create a new table from uploaded data"""
290301
try:
291-
if 'file' not in request.files:
292-
return jsonify({"status": "error", "message": "No file provided"}), 400
302+
if 'file' not in request.files and 'raw_data' not in request.form:
303+
return jsonify({"status": "error", "message": "No file or raw data provided"}), 400
293304

294-
file = request.files['file']
295305
table_name = request.form.get('table_name')
296-
297-
print(f"table_name: {table_name}")
298-
print(f"file: {file.filename}")
299-
print(f"file: {file}")
300-
301306
if not table_name:
302307
return jsonify({"status": "error", "message": "No table name provided"}), 400
303-
304-
# Sanitize table name:
305-
# 1. Convert to lowercase
306-
# 2. Replace hyphens with underscores
307-
# 3. Replace spaces with underscores
308-
# 4. Remove any other special characters
309-
sanitized_table_name = table_name.lower()
310-
sanitized_table_name = sanitized_table_name.replace('-', '_')
311-
sanitized_table_name = sanitized_table_name.replace(' ', '_')
312-
sanitized_table_name = ''.join(c for c in sanitized_table_name if c.isalnum() or c == '_')
313308

314-
# Ensure table name starts with a letter
315-
if not sanitized_table_name or not sanitized_table_name[0].isalpha():
316-
sanitized_table_name = 'table_' + sanitized_table_name
317-
318-
# Verify we have a valid table name after sanitization
319-
if not sanitized_table_name:
320-
return jsonify({"status": "error", "message": "Invalid table name"}), 400
309+
df = None
310+
if 'file' in request.files:
311+
file = request.files['file']
312+
# Read file based on extension
313+
if file.filename.endswith('.csv'):
314+
df = pd.read_csv(file)
315+
elif file.filename.endswith(('.xlsx', '.xls')):
316+
df = pd.read_excel(file)
317+
elif file.filename.endswith('.json'):
318+
df = pd.read_json(file)
319+
else:
320+
return jsonify({"status": "error", "message": "Unsupported file format"}), 400
321+
else:
322+
raw_data = request.form.get('raw_data')
323+
try:
324+
df = pd.DataFrame(json.loads(raw_data))
325+
except Exception as e:
326+
return jsonify({"status": "error", "message": f"Invalid JSON data: {str(e)}, it must be in the format of a list of dictionaries"}), 400
327+
328+
if df is None:
329+
return jsonify({"status": "error", "message": "No data provided"}), 400
330+
331+
sanitized_table_name = sanitize_table_name(table_name)
321332

322333
with db_manager.connection(session['session_id']) as db:
323334
# Check if table exists and generate unique name if needed
@@ -331,16 +342,6 @@ def create_table():
331342
# If exists, append counter to base name
332343
sanitized_table_name = f"{base_name}_{counter}"
333344
counter += 1
334-
335-
# Read file based on extension
336-
if file.filename.endswith('.csv'):
337-
df = pd.read_csv(file)
338-
elif file.filename.endswith(('.xlsx', '.xls')):
339-
df = pd.read_excel(file)
340-
elif file.filename.endswith('.json'):
341-
df = pd.read_json(file)
342-
else:
343-
return jsonify({"status": "error", "message": "Unsupported file format"}), 400
344345

345346
# Create table
346347
db.register('df_temp', df)
@@ -364,6 +365,8 @@ def create_table():
364365
"message": safe_msg
365366
}), status_code
366367

368+
369+
367370
@tables_bp.route('/delete-table', methods=['POST'])
368371
def drop_table():
369372
"""Drop a table or view"""
@@ -679,6 +682,29 @@ def analyze_table():
679682
"message": safe_msg
680683
}), status_code
681684

685+
def sanitize_table_name(table_name: str) -> str:
    """
    Sanitize a table name to be a valid DuckDB table name.

    The name is lowercased, hyphens and spaces become underscores, and every
    remaining character that is not alphanumeric or an underscore is dropped.
    A result that does not start with a letter is prefixed with ``table_``.

    Args:
        table_name: The user-supplied table name.

    Returns:
        The sanitized name. If no usable characters remain, a unique
        ``table_<32 hex chars>`` name is generated instead.
    """
    normalized = table_name.lower().replace('-', '_').replace(' ', '_')
    sanitized_table_name = ''.join(c for c in normalized if c.isalnum() or c == '_')

    # Nothing usable survived: fall back to a generated name. This has to be
    # checked before the prefix step below, which would otherwise turn the
    # empty string into the bare name 'table_'. The hex form is used because
    # the default UUID string contains hyphens, which are not valid in an
    # unquoted identifier.
    if not sanitized_table_name:
        return f'table_{uuid.uuid4().hex}'

    # Ensure table name starts with a letter
    if not sanitized_table_name[0].isalpha():
        sanitized_table_name = 'table_' + sanitized_table_name

    return sanitized_table_name
707+
682708
def sanitize_db_error_message(error: Exception) -> Tuple[str, int]:
683709
"""
684710
Sanitize error messages before sending to client.

src/app/App.tsx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,10 @@ const TableMenu: React.FC = () => {
197197
anchorEl={anchorEl}
198198
open={open}
199199
onClose={() => setAnchorEl(null)}
200-
MenuListProps={{
201-
'aria-labelledby': 'add-table-button',
202-
sx: { py: '4px', px: '8px' }
200+
slotProps={{
201+
paper: { sx: { py: '4px', px: '8px' } }
203202
}}
203+
aria-labelledby="add-table-button"
204204
sx={{ '& .MuiMenuItem-root': { padding: 0, margin: 0 } }}
205205
>
206206
<MenuItem onClick={(e) => {
@@ -247,16 +247,16 @@ const SessionMenu: React.FC = () => {
247247
anchorEl={anchorEl}
248248
open={open}
249249
onClose={() => setAnchorEl(null)}
250-
MenuListProps={{
251-
'aria-labelledby': 'session-menu-button',
252-
sx: { py: '4px', px: '8px' }
250+
slotProps={{
251+
paper: { sx: { py: '4px', px: '8px' } }
253252
}}
253+
aria-labelledby="session-menu-button"
254254
sx={{ '& .MuiMenuItem-root': { padding: 0, margin: 0 } }}
255255
>
256256
{sessionId && (
257257
<MenuItem disabled>
258-
<Typography sx={{ fontSize: 12, color: 'text.secondary', mx: 2 }}>
259-
ID: {sessionId}
258+
<Typography sx={{ fontSize: 12, color: 'text.secondary'}}>
259+
session id: {sessionId}
260260
</Typography>
261261
</MenuItem>
262262
)}

src/app/dfSlice.tsx

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ import { getDataTable } from '../views/VisualizationView';
1111
import { adaptChart, getTriggers, getUrls } from './utils';
1212
import { Type } from '../data/types';
1313
import { TableChallenges } from '../views/TableSelectionView';
14-
import { inferTypeFromValueArray } from '../data/utils';
14+
import { createTableFromFromObjectArray, inferTypeFromValueArray } from '../data/utils';
15+
import { handleSSEMessage } from './SSEActions';
1516

1617
enableMapSet();
1718

@@ -27,7 +28,7 @@ export const generateFreshChart = (tableRef: string, chartType?: string) : Chart
2728
}
2829

2930
export interface SSEMessage {
30-
type: "notification" | "action";
31+
type: "heartbeat" | "notification" | "action";
3132
text: string;
3233
data?: Record<string, any>;
3334
timestamp: number;
@@ -81,7 +82,7 @@ export interface DataFormulatorState {
8182

8283
dataLoaderConnectParams: Record<string, Record<string, string>>; // {table_name: {param_name: param_value}}
8384

84-
lastSSEMessage: SSEMessage | undefined; // Store the last received SSE message
85+
pendingSSEActions: SSEMessage[]; // Actions taken by the server but not yet completed
8586
}
8687

8788
// Define the initial state using that type
@@ -123,7 +124,7 @@ const initialState: DataFormulatorState = {
123124

124125
dataLoaderConnectParams: {},
125126

126-
lastSSEMessage: undefined,
127+
pendingSSEActions: [],
127128
}
128129

129130
let getUnrefedDerivedTableIds = (state: DataFormulatorState) => {
@@ -768,24 +769,7 @@ export const dataFormulatorSlice = createSlice({
768769
delete state.dataLoaderConnectParams[dataLoaderType];
769770
},
770771
handleSSEMessage: (state, action: PayloadAction<SSEMessage>) => {
771-
state.lastSSEMessage = action.payload;
772-
if (action.payload.type == "notification") {
773-
console.log('SSE message stored in Redux:', action.payload);
774-
state.messages = [...state.messages, {
775-
component: "server",
776-
type: "info",
777-
timestamp: action.payload.timestamp,
778-
value: action.payload.text || "Unknown message"
779-
}];
780-
} else if (action.payload.type == "action") {
781-
console.log('SSE message stored in Redux:', action.payload);
782-
state.messages = [...state.messages, {
783-
component: "server",
784-
type: "info",
785-
timestamp: action.payload.timestamp,
786-
value: action.payload.text || "Unknown message"
787-
}];
788-
}
772+
handleSSEMessage(state, action.payload);
789773
},
790774
clearMessages: (state) => {
791775
state.messages = [];

0 commit comments

Comments
 (0)