Skip to content

Commit 44b6373

Browse files
committed
supporting connection to mysql and pgsql
1 parent 56c3ce7 commit 44b6373

File tree

6 files changed

+195
-17
lines changed

6 files changed

+195
-17
lines changed

.env.template

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,15 @@
33
# python -m data_formulator -p 5000 --exec-python-in-subprocess true --disable-display-keys true
44

55
DISABLE_DISPLAY_KEYS=false # if true, the display keys will not be shown in the frontend
6-
EXEC_PYTHON_IN_SUBPROCESS=false # if true, the python code will be executed in a subprocess to avoid crashing the main app, but it will increase the time of response
6+
EXEC_PYTHON_IN_SUBPROCESS=false # if true, the python code will be executed in a subprocess to avoid crashing the main app, but it will increase the time of response
7+
8+
# External database connection settings
9+
# check https://duckdb.org/docs/stable/extensions/mysql.html and https://duckdb.org/docs/stable/extensions/postgres.html
10+
USE_EXTERNAL_DB=false # if true, the app will use an external database instead of the one in the app
11+
DB_NAME=mysql_db # the name to refer to this database connection
12+
DB_TYPE=mysql # mysql or postgresql
13+
DB_HOST=localhost
14+
DB_PORT=0
15+
DB_DATABASE=mysql
16+
DB_USER=root
17+
DB_PASSWORD=

py-src/data_formulator/app.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,20 @@
1515

1616
import webbrowser
1717
import threading
18+
import numpy as np
19+
import datetime
20+
import time
1821

1922
import logging
2023

2124
import json
22-
import time
2325
from pathlib import Path
2426

2527
from vega_datasets import data as vega_data
2628

2729
from dotenv import load_dotenv
2830
import secrets
29-
31+
import base64
3032
APP_ROOT = Path(os.path.join(Path(__file__).parent)).absolute()
3133

3234
import os
@@ -38,6 +40,16 @@
3840
app = Flask(__name__, static_url_path='', static_folder=os.path.join(APP_ROOT, "dist"))
3941
app.secret_key = secrets.token_hex(16) # Generate a random secret key for sessions
4042

43+
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy values and raw byte strings.

    The stock encoder rejects NumPy scalars/arrays and ``bytes``; this
    subclass converts them to plain JSON-compatible Python values.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Conversions:
            * any NumPy integer scalar (``int8`` ... ``int64``, unsigned too) -> ``int``
            * any NumPy floating scalar (e.g. ``float32``) -> ``float``
            * ``np.bool_`` -> ``bool``
            * ``np.ndarray`` -> nested ``list`` via ``tolist()``
            * ``bytes`` / ``bytearray`` -> base64 text (ASCII)

        Anything else is delegated to the base class, which raises
        ``TypeError`` for unsupported types.
        """
        # np.integer is the common base of every NumPy integer width, so this
        # covers np.int64 (the original case) as well as int32, uint8, etc.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, (bytes, bytearray)):
            # Binary payloads are not valid JSON; ship them as base64 text.
            return base64.b64encode(obj).decode('ascii')
        return super().default(obj)
50+
51+
app.json_encoder = CustomJSONEncoder
52+
4153
# Load env files early
4254
load_dotenv(os.path.join(APP_ROOT, "..", "..", 'api-keys.env'))
4355
load_dotenv(os.path.join(APP_ROOT, 'api-keys.env'))

py-src/data_formulator/db_manager.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import duckdb
22
import pandas as pd
3-
from typing import Optional, Dict, List, ContextManager
3+
from typing import Optional, Dict, List, ContextManager, Any, Tuple
44
import time
55
from flask import session
66
import tempfile
77
import os
88
from contextlib import contextmanager
9+
from dotenv import load_dotenv
910

1011
class DuckDBManager:
11-
def __init__(self):
12+
def __init__(self, external_db_connections: Dict[str, Dict[str, Any]]):
1213
# Store session db file paths
1314
self._db_files: Dict[str, str] = {}
15+
# Track which extensions have been installed for which db files
16+
self._installed_extensions: Dict[str, List[str]] = {}
17+
18+
# external db connections
19+
self._external_db_connections: Dict[str, Dict[str, Any]] = external_db_connections
1420

1521
@contextmanager
1622
def connection(self, session_id: str) -> ContextManager[duckdb.DuckDBPyConnection]:
@@ -31,13 +37,49 @@ def get_connection(self, session_id: str) -> duckdb.DuckDBPyConnection:
3137
db_file = os.path.join(tempfile.gettempdir(), f"df_{session_id}.db")
3238
print(f"=== Creating new db file: {db_file}")
3339
self._db_files[session_id] = db_file
40+
# Initialize extension tracking for this file
41+
self._installed_extensions[db_file] = []
3442
else:
3543
print(f"=== Using existing db file: {self._db_files[session_id]}")
3644
db_file = self._db_files[session_id]
3745

3846
# Create a fresh connection to the database file
3947
conn = duckdb.connect(database=db_file)
48+
49+
if self._external_db_connections and self._external_db_connections['db_type'] in ['mysql', 'postgresql']:
50+
db_name = self._external_db_connections['db_name']
51+
db_type = self._external_db_connections['db_type']
52+
53+
print(f"=== connecting to {db_type} extension")
54+
# Only install if not already installed for this db file
55+
if db_type not in self._installed_extensions.get(db_file, []):
56+
conn.execute(f"INSTALL {db_type};")
57+
self._installed_extensions[db_file].append(db_type)
58+
59+
conn.execute(f"LOAD {db_type};")
60+
conn.execute(f"""CREATE SECRET (
61+
TYPE {db_type},
62+
HOST '{self._external_db_connections['host']}',
63+
PORT '{self._external_db_connections['port']}',
64+
DATABASE '{self._external_db_connections['database']}',
65+
USER '{self._external_db_connections['user']}',
66+
PASSWORD '{self._external_db_connections['password']}');
67+
""")
68+
conn.execute(f"ATTACH '' AS {db_name} (TYPE {db_type});")
69+
# result = conn.execute(f"SELECT * FROM {db_name}.information_schema.tables WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys');").fetch_df()
70+
# print(f"=== result: {result}")
71+
4072
return conn
4173

74+
env = load_dotenv()
75+
4276
# Initialize the DB manager
43-
db_manager = DuckDBManager()
77+
db_manager = DuckDBManager({
78+
"db_name": os.getenv('DB_NAME'),
79+
"db_type": os.getenv('DB_TYPE'),
80+
"host": os.getenv('DB_HOST'),
81+
"port": os.getenv('DB_PORT'),
82+
"database": os.getenv('DB_DATABASE'),
83+
"user": os.getenv('DB_USER'),
84+
"password": os.getenv('DB_PASSWORD')
85+
} if os.getenv('USE_EXTERNAL_DB') == 'true' else None)

py-src/data_formulator/tables_routes.py

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import mimetypes
88
mimetypes.add_type('application/javascript', '.js')
99
mimetypes.add_type('application/javascript', '.mjs')
10+
import json
1011

1112
from flask import request, send_from_directory, session, jsonify, Blueprint
1213
import pandas as pd
13-
14+
import random
15+
import string
1416
from pathlib import Path
1517

1618
from data_formulator.db_manager import db_manager
@@ -39,15 +41,18 @@ def list_tables():
3941
try:
4042
result = []
4143
with db_manager.connection(session['session_id']) as db:
42-
tables = db.execute("SHOW TABLES").fetchall()
44+
table_metadata_list = db.execute("SELECT database_name, schema_name, table_name, schema_name==current_schema() as is_current_schema FROM duckdb_tables() WHERE internal=False").fetchall()
4345

44-
for table in tables:
45-
table_name = table[0]
46+
print(f"table_metadata_list: {table_metadata_list}")
47+
for table_metadata in table_metadata_list:
48+
[database_name, schema_name, table_name, is_current_schema] = table_metadata
49+
50+
table_name = table_name if is_current_schema else '.'.join([database_name, schema_name, table_name])
4651
# Get column information
4752
columns = db.execute(f"DESCRIBE {table_name}").fetchall()
4853
# Get row count
4954
row_count = db.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
50-
sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchall()
55+
sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchdf()
5156

5257
# Check if this is a view or a table
5358
try:
@@ -62,7 +67,7 @@ def list_tables():
6267
"name": table_name,
6368
"columns": [{"name": col[0], "type": col[1]} for col in columns],
6469
"row_count": row_count,
65-
"sample_rows": [dict(zip([col[0] for col in columns], row)) for row in sample_rows],
70+
"sample_rows": json.loads(sample_rows.to_json(orient='records')),
6671
"view_source": view_source
6772
})
6873

@@ -186,11 +191,11 @@ def sample_table():
186191
else:
187192
query += f" ORDER BY ROWID DESC LIMIT {sample_size}"
188193

189-
result = db.execute(query).fetchall()
194+
result = db.execute(query).fetchdf()
190195

191196
return jsonify({
192197
"status": "success",
193-
"rows": [dict(zip(output_column_names, row)) for row in result],
198+
"rows": json.loads(result.to_json(orient='records')),
194199
"total_row_count": total_row_count
195200
})
196201
except Exception as e:
@@ -437,6 +442,108 @@ def upload_db_file():
437442
"message": safe_msg
438443
}), status_code
439444

445+
446+
def validate_db_connection_params(db_type: str, db_host: str, db_port: int,
                                  db_database: str, db_user: str, db_password: str):
    """Validate external database connection parameters.

    Args:
        db_type: Database kind; ``'postgresql'`` or ``'mysql'`` (case-insensitive).
        db_host: Host name or IP; letters, digits, dots and hyphens only.
        db_port: Port number (or numeric string) in the range 1-65535.
        db_database: Database name; letters, digits and underscores only.
        db_user: User name; letters, digits and ``@ . _ -`` only.
        db_password: Password; must be non-empty.

    Returns:
        None when every parameter is acceptable.

    Raises:
        ValueError: If any parameter is missing, of the wrong type, or malformed.
    """
    # Imported locally so the validator does not rely on a module-level import.
    import re

    def _is_text_matching(pattern: str, value) -> bool:
        # fullmatch (rather than match with '^...$') so a trailing newline is
        # rejected: '$' also matches just before a final '\n'.
        return isinstance(value, str) and re.fullmatch(pattern, value) is not None

    # Validate db_type
    valid_db_types = ['postgresql', 'mysql']
    if not isinstance(db_type, str) or db_type.lower() not in valid_db_types:
        raise ValueError(f"Invalid database type. Must be one of: {', '.join(valid_db_types)}")

    # Validate host (basic DNS/IP format check)
    if not _is_text_matching(r'[a-zA-Z0-9.-]+', db_host):
        raise ValueError("Invalid host format")

    # Validate port
    try:
        port = int(db_port)
    except (ValueError, TypeError):
        port = None
    if port is None or not (1 <= port <= 65535):
        raise ValueError("Port must be a number between 1 and 65535")

    # Validate database name (alphanumeric and underscores only)
    if not _is_text_matching(r'[a-zA-Z0-9_]+', db_database):
        raise ValueError("Invalid database name format")

    # Validate username (alphanumeric and some special chars)
    if not _is_text_matching(r'[a-zA-Z0-9@._-]+', db_user):
        raise ValueError("Invalid username format")

    # Validate password exists
    if not db_password:
        raise ValueError("Password cannot be empty")
477+
478+
@tables_bp.route('/attach-external-db', methods=['POST'])
def attach_external_db():
    """Attach an external MySQL/PostgreSQL database to the session's DuckDB.

    Expects a JSON body with ``db_type``, ``db_host``, ``db_port``,
    ``db_database``, ``db_user`` and ``db_password``.

    Returns:
        A JSON response. On success, ``result`` is a list of records from the
        attached database's ``information_schema.tables`` (system schemas
        excluded). On failure, ``status`` is ``"error"`` with a message; the
        HTTP status is 400 for a missing session or invalid parameters,
        otherwise whatever ``sanitize_db_error_message`` reports.
    """
    if 'session_id' not in session:
        return jsonify({"status": "error", "message": "No session ID found"}), 400

    # silent=True: a missing/invalid JSON body becomes {} and is then rejected
    # by validation below instead of raising inside Flask.
    data = request.get_json(silent=True) or {}
    db_type = data.get('db_type')
    db_host = data.get('db_host')
    db_port = data.get('db_port')
    db_database = data.get('db_database')
    db_user = data.get('db_user')
    db_password = data.get('db_password')

    # The values below are spliced into SQL text, so they must be validated
    # before use. The validator's messages are fixed strings, safe to echo back.
    try:
        validate_db_connection_params(db_type, db_host, db_port,
                                      db_database, db_user, db_password)
    except ValueError as e:
        return jsonify({"status": "error", "message": str(e)}), 400

    db_type = db_type.lower()
    # DuckDB calls the PostgreSQL extension (and its secret/attach TYPE) "postgres".
    duckdb_type = 'postgres' if db_type == 'postgresql' else db_type

    # Random alias for the attached database. Lowercase only, because DuckDB
    # identifiers are case-insensitive; six characters keep collisions unlikely.
    suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=6))
    db_name = f"{db_type}_{suffix}"
    secret_name = f"{db_name}_secret"

    def _sql_literal(value) -> str:
        # Quote as a SQL string literal, doubling embedded single quotes.
        return "'" + str(value).replace("'", "''") + "'"

    try:
        with db_manager.connection(session['session_id']) as conn:
            # Install and load the extension
            conn.install_extension(duckdb_type)
            conn.load_extension(duckdb_type)

            # A named secret avoids clashing with a default secret of the same
            # type created by an earlier attach. The password is never logged.
            conn.execute(f"""CREATE SECRET {secret_name} (
                TYPE {duckdb_type},
                HOST {_sql_literal(db_host)},
                PORT {int(db_port)},
                DATABASE {_sql_literal(db_database)},
                USER {_sql_literal(db_user)},
                PASSWORD {_sql_literal(db_password)}
            );""")

            # Attach the database
            conn.execute(f"ATTACH '' AS {db_name} (TYPE {duckdb_type}, SECRET {secret_name});")

            tables = conn.execute(f"SELECT * FROM {db_name}.information_schema.tables WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys');").fetchdf()

        logger.info(f"Attached external database as {db_name}; found {len(tables)} tables")

        return jsonify({
            "status": "success",
            "message": "External database attached successfully",
            # A DataFrame is not JSON-serializable; convert it to plain records.
            "result": json.loads(tables.to_json(orient='records'))
        })

    except Exception as e:
        logger.error(f"Error attaching external database: {str(e)}")
        safe_msg, status_code = sanitize_db_error_message(e)
        return jsonify({
            "status": "error",
            "message": safe_msg
        }), status_code
545+
546+
440547
@tables_bp.route('/download-db-file', methods=['GET'])
441548
def download_db_file():
442549
"""Download the db file for a session"""

src/app/utils.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ export function getUrls() {
5757
UPLOAD_DB_FILE: `/api/tables/upload-db-file`,
5858
DOWNLOAD_DB_FILE: `/api/tables/download-db-file`,
5959
RESET_DB_FILE: `/api/tables/reset-db-file`,
60+
ATTACH_EXTERNAL_DB: `/api/tables/attach-external-db`,
6061

6162
LIST_TABLES: `/api/tables/list-tables`,
6263
TABLE_DATA: `/api/tables/get-table`,

src/views/DBTableManager.tsx

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ import {
2929
TableRow,
3030
CircularProgress,
3131
ButtonGroup,
32-
Tooltip
32+
Tooltip,
33+
MenuItem
3334
} from '@mui/material';
3435
import DeleteIcon from '@mui/icons-material/Delete';
3536
import UploadFileIcon from '@mui/icons-material/UploadFile';
@@ -548,7 +549,7 @@ export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function
548549
{...a11yProps(i)}
549550
/>
550551
))}
551-
</Tabs>
552+
</Tabs>
552553
<Divider sx={{my: 1}} textAlign='left'> <TuneIcon sx={{fontSize: 12, color: "text.secondary"}} /></Divider>
553554
{uploadFileButton(<Typography component="span" fontSize={12}>{isUploading ? 'uploading...' : 'upload file'}</Typography>)}
554555
</Box>
@@ -595,7 +596,11 @@ export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function
595596
/>
596597
) : (
597598
<CustomReactTable
598-
rows={currentTable.sample_rows.slice(0, 9)}
599+
rows={currentTable.sample_rows.map((row: any) => {
600+
return Object.fromEntries(Object.entries(row).map(([key, value]: [string, any]) => {
601+
return [key, String(value)];
602+
}));
603+
}).slice(0, 9)}
599604
columnDefs={currentTable.columns.map(col => ({
600605
id: col.name,
601606
label: col.name,

0 commit comments

Comments
 (0)