
Commit e8799a6

Merge pull request #159 from slackroo/main
"Add S3 data loader support to DBTableManager and data formulator"
2 parents e910fe4 + 48f4ee2 commit e8799a6

4 files changed, +197 -4 lines changed

Lines changed: 4 additions & 2 deletions
@@ -1,10 +1,12 @@
 from data_formulator.data_loader.external_data_loader import ExternalDataLoader
 from data_formulator.data_loader.mysql_data_loader import MySQLDataLoader
 from data_formulator.data_loader.kusto_data_loader import KustoDataLoader
+from data_formulator.data_loader.s3_data_loader import S3DataLoader

 DATA_LOADERS = {
     "mysql": MySQLDataLoader,
-    "kusto": KustoDataLoader
+    "kusto": KustoDataLoader,
+    "s3": S3DataLoader,
 }

-__all__ = ["ExternalDataLoader", "MySQLDataLoader", "KustoDataLoader", "DATA_LOADERS"]
+__all__ = ["ExternalDataLoader", "MySQLDataLoader", "KustoDataLoader", "S3DataLoader", "DATA_LOADERS"]
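
The registry diff above is what exposes the new loader to the rest of the app: a string key maps to a loader class, and each class reports its connection parameters via list_params(). Below is a minimal sketch of how a caller might use the registry, assuming only what this commit shows and that the diff above is the data_formulator.data_loader package __init__; the create_loader helper and the surrounding wiring are illustrative, not part of the commit.

# Illustrative only -- not part of the commit.
import duckdb
from data_formulator.data_loader import DATA_LOADERS

def create_loader(loader_type: str, params: dict, conn: duckdb.DuckDBPyConnection):
    # e.g. "s3" -> S3DataLoader; every registered loader takes (params, duck_db_conn)
    return DATA_LOADERS[loader_type](params, conn)

# A front end can ask a loader class which parameters to prompt the user for:
for p in DATA_LOADERS["s3"].list_params():
    print(p["name"], "(required)" if p["required"] else "(optional)")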
Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+import json
+import pandas as pd
+import duckdb
+import os
+
+from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
+from typing import Dict, Any, List
+
+class S3DataLoader(ExternalDataLoader):
+
+    @staticmethod
+    def list_params() -> List[Dict[str, Any]]:
+        params_list = [
+            {"name": "aws_access_key_id", "type": "string", "required": True, "default": "", "description": "AWS access key ID"},
+            {"name": "aws_secret_access_key", "type": "string", "required": True, "default": "", "description": "AWS secret access key"},
+            {"name": "aws_session_token", "type": "string", "required": False, "default": "", "description": "AWS session token (required for temporary credentials)"},
+            {"name": "region_name", "type": "string", "required": True, "default": "us-east-1", "description": "AWS region name"},
+            {"name": "bucket", "type": "string", "required": True, "default": "", "description": "S3 bucket name"}
+        ]
+        return params_list
+
+    def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnection):
+        self.params = params
+        self.duck_db_conn = duck_db_conn
+
+        # Extract parameters
+        self.aws_access_key_id = params.get("aws_access_key_id", "")
+        self.aws_secret_access_key = params.get("aws_secret_access_key", "")
+        self.aws_session_token = params.get("aws_session_token", "")
+        self.region_name = params.get("region_name", "us-east-1")
+        self.bucket = params.get("bucket", "")
+
+        # Install and load the httpfs extension for S3 access
+        self.duck_db_conn.install_extension("httpfs")
+        self.duck_db_conn.load_extension("httpfs")
+
+        # Set AWS credentials for DuckDB
+        self.duck_db_conn.execute(f"SET s3_region='{self.region_name}'")
+        self.duck_db_conn.execute(f"SET s3_access_key_id='{self.aws_access_key_id}'")
+        self.duck_db_conn.execute(f"SET s3_secret_access_key='{self.aws_secret_access_key}'")
+        if self.aws_session_token: # Add this block
+            self.duck_db_conn.execute(f"SET s3_session_token='{self.aws_session_token}'")
+
+    def list_tables(self) -> List[Dict[str, Any]]:
+        # Use boto3 to list objects in the bucket
+        import boto3
+
+        s3_client = boto3.client(
+            's3',
+            aws_access_key_id=self.aws_access_key_id,
+            aws_secret_access_key=self.aws_secret_access_key,
+            aws_session_token=self.aws_session_token if self.aws_session_token else None,
+            region_name=self.region_name
+        )
+
+        # List objects in the bucket
+        response = s3_client.list_objects_v2(Bucket=self.bucket)
+
+        results = []
+
+        if 'Contents' in response:
+            for obj in response['Contents']:
+                key = obj['Key']
+
+                # Skip directories and non-data files
+                if key.endswith('/') or not self._is_supported_file(key):
+                    continue
+
+                # Create S3 URL
+                s3_url = f"s3://{self.bucket}/{key}"
+
+                try:
+                    # Choose the appropriate read function based on file extension
+                    if s3_url.lower().endswith('.parquet'):
+                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_parquet('{s3_url}') LIMIT 10").df()
+                    elif s3_url.lower().endswith('.json') or s3_url.lower().endswith('.jsonl'):
+                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_json_auto('{s3_url}') LIMIT 10").df()
+                    elif s3_url.lower().endswith('.csv'): # Default to CSV for other formats
+                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT 10").df()
+
+                    # Get column information
+                    columns = [{
+                        'name': col,
+                        'type': str(sample_df[col].dtype)
+                    } for col in sample_df.columns]
+
+                    # Get sample data
+                    sample_rows = json.loads(sample_df.to_json(orient="records"))
+
+                    # Estimate row count (this is approximate for CSV files)
+                    row_count = self._estimate_row_count(s3_url)
+
+                    table_metadata = {
+                        "row_count": row_count,
+                        "columns": columns,
+                        "sample_rows": sample_rows
+                    }
+
+                    results.append({
+                        "name": s3_url,
+                        "metadata": table_metadata
+                    })
+                except Exception as e:
+                    # Skip files that can't be read
+                    print(f"Error reading {s3_url}: {e}")
+                    continue
+
+        return results
+
+    def _is_supported_file(self, key: str) -> bool:
+        """Check if the file type is supported by DuckDB."""
+        supported_extensions = ['.csv', '.parquet', '.json', '.jsonl']
+        return any(key.lower().endswith(ext) for ext in supported_extensions)
+
+    def _estimate_row_count(self, s3_url: str) -> int:
+        """Estimate the number of rows in a file."""
+        try:
+            # For parquet files, we can get the exact count
+            if s3_url.lower().endswith('.parquet'):
+                count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM read_parquet('{s3_url}')").fetchone()[0]
+                return count
+
+            # For CSV files, we'll sample the file to estimate size
+            sample_size = 1000
+            sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT {sample_size}").df()
+
+            # Get file size from S3
+            import boto3
+            s3_client = boto3.client(
+                's3',
+                aws_access_key_id=self.aws_access_key_id,
+                aws_secret_access_key=self.aws_secret_access_key,
+                aws_session_token=self.aws_session_token if self.aws_session_token else None,
+                region_name=self.region_name
+            )
+
+            key = s3_url.replace(f"s3://{self.bucket}/", "")
+            response = s3_client.head_object(Bucket=self.bucket, Key=key)
+            file_size = response['ContentLength']
+
+            # Estimate based on sample size and file size
+            if len(sample_df) > 0:
+                # Calculate average row size in bytes
+                avg_row_size = file_size / len(sample_df)
+                estimated_rows = int(file_size / avg_row_size)
+                return min(estimated_rows, 1000000) # Cap at 1 million for UI performance
+
+            return 0
+        except Exception as e:
+            print(f"Error estimating row count for {s3_url}: {e}")
+            return 0
+
+    def ingest_data(self, table_name: str, name_as: str = None, size: int = 1000000):
+        if name_as is None:
+            name_as = table_name.split('/')[-1].split('.')[0]
+
+        name_as = sanitize_table_name(name_as)
+
+        # Determine file type and use appropriate DuckDB function
+        if table_name.lower().endswith('.csv'):
+            self.duck_db_conn.execute(f"""
+                CREATE OR REPLACE TABLE main.{name_as} AS
+                SELECT * FROM read_csv_auto('{table_name}')
+                LIMIT {size}
+            """)
+        elif table_name.lower().endswith('.parquet'):
+            self.duck_db_conn.execute(f"""
+                CREATE OR REPLACE TABLE main.{name_as} AS
+                SELECT * FROM read_parquet('{table_name}')
+                LIMIT {size}
+            """)
+        elif table_name.lower().endswith('.json') or table_name.lower().endswith('.jsonl'):
+            self.duck_db_conn.execute(f"""
+                CREATE OR REPLACE TABLE main.{name_as} AS
+                SELECT * FROM read_json_auto('{table_name}')
+                LIMIT {size}
+            """)
+        else:
+            raise ValueError(f"Unsupported file type: {table_name}")
+
+    def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
+        return self.duck_db_conn.execute(query).df().head(10).to_dict(orient="records")
+
+    def ingest_data_from_query(self, query: str, name_as: str):
+        # Execute the query and get results as a DataFrame
+        df = self.duck_db_conn.execute(query).df()
+        # Use the base class's method to ingest the DataFrame
+        self.ingest_df_to_duckdb(df, name_as)
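
For reference, here is a hedged usage sketch of the new S3DataLoader outside the Data Formulator UI, using only the methods defined in the file above; the credentials, bucket, and object key are placeholders, not part of the commit.

# Usage sketch -- placeholder credentials/bucket/object key.
import duckdb
from data_formulator.data_loader.s3_data_loader import S3DataLoader

conn = duckdb.connect()  # in-memory DuckDB database
params = {
    "aws_access_key_id": "<ACCESS_KEY_ID>",
    "aws_secret_access_key": "<SECRET_ACCESS_KEY>",
    "aws_session_token": "",        # only needed for temporary credentials
    "region_name": "us-east-1",
    "bucket": "<bucket-name>",
}

# The constructor installs/loads DuckDB's httpfs extension and sets the S3 credentials.
loader = S3DataLoader(params, conn)

# Enumerate readable objects (.csv/.parquet/.json/.jsonl) with column info and sample rows.
for table in loader.list_tables():
    print(table["name"], table["metadata"]["row_count"])

# Materialize one object as a DuckDB table and preview it.
loader.ingest_data("s3://<bucket-name>/path/to/data.parquet", name_as="my_table")
print(loader.view_query_sample("SELECT * FROM main.my_table"))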

requirements.txt

Lines changed: 4 additions & 1 deletion
@@ -8,8 +8,11 @@ openai
 azure-identity
 azure-kusto-data
 azure-keyvault-secrets
+azure-kusto-data
+azure-storage-blob
 python-dotenv
 vega_datasets
 litellm
 duckdb
--e . #also need to install data formulator itself
+boto3
+-e . #also need to install data formulator itself
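
The new boto3 requirement covers only object discovery (list_objects_v2, head_object); the actual reads go through DuckDB's httpfs extension, which the loader installs at runtime, so no extra S3 SDK sits on the query path. A minimal sketch of that split, with placeholder bucket, object, and credentials (illustrative, not part of the commit):

# Illustrative only -- placeholder bucket/object names and credentials.
import boto3
import duckdb

# boto3 (new requirement): enumerate objects and inspect their metadata.
s3 = boto3.client("s3", region_name="us-east-1",
                  aws_access_key_id="<ACCESS_KEY_ID>",
                  aws_secret_access_key="<SECRET_ACCESS_KEY>")
keys = [o["Key"] for o in s3.list_objects_v2(Bucket="<bucket-name>").get("Contents", [])]

# duckdb httpfs extension: read object contents directly from S3.
conn = duckdb.connect()
conn.install_extension("httpfs")
conn.load_extension("httpfs")
conn.execute("SET s3_region='us-east-1'")
conn.execute("SET s3_access_key_id='<ACCESS_KEY_ID>'")
conn.execute("SET s3_secret_access_key='<SECRET_ACCESS_KEY>'")
df = conn.execute("SELECT * FROM read_parquet('s3://<bucket-name>/data.parquet') LIMIT 10").df()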

src/views/DBTableManager.tsx

Lines changed: 1 addition & 1 deletion
@@ -653,7 +653,7 @@ export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function
         sx={{px: 0.5}}
     >
         <Typography variant="caption" sx={{color: "text.secondary", fontWeight: "bold", px: 1}}>connect external data</Typography>
-        {["file upload", "mysql", "kusto"].map((dataLoaderType, i) => (
+        {["file upload", "mysql", "kusto","s3"].map((dataLoaderType, i) => (
             <Tab
                 key={`dataLoader:${dataLoaderType}`}
                 wrapped
