import json
import pandas as pd
import duckdb
import os

from data_formulator.data_loader.external_data_loader import ExternalDataLoader, sanitize_table_name
from typing import Dict, Any, List, Optional

class S3DataLoader(ExternalDataLoader):
    """Load CSV, Parquet, and JSON/JSONL objects from an S3 bucket into DuckDB."""

    @staticmethod
    def list_params() -> List[Dict[str, Any]]:
        params_list = [
            {"name": "aws_access_key_id", "type": "string", "required": True, "default": "", "description": "AWS access key ID"},
            {"name": "aws_secret_access_key", "type": "string", "required": True, "default": "", "description": "AWS secret access key"},
            {"name": "aws_session_token", "type": "string", "required": False, "default": "", "description": "AWS session token (required for temporary credentials)"},
            {"name": "region_name", "type": "string", "required": True, "default": "us-east-1", "description": "AWS region name"},
            {"name": "bucket", "type": "string", "required": True, "default": "", "description": "S3 bucket name"}
        ]
        return params_list

    def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnection):
        self.params = params
        self.duck_db_conn = duck_db_conn

        # Extract parameters
        self.aws_access_key_id = params.get("aws_access_key_id", "")
        self.aws_secret_access_key = params.get("aws_secret_access_key", "")
        self.aws_session_token = params.get("aws_session_token", "")
        self.region_name = params.get("region_name", "us-east-1")
        self.bucket = params.get("bucket", "")

        # Install and load the httpfs extension for S3 access
        self.duck_db_conn.install_extension("httpfs")
        self.duck_db_conn.load_extension("httpfs")

        # Set AWS credentials for DuckDB
        self.duck_db_conn.execute(f"SET s3_region='{self.region_name}'")
        self.duck_db_conn.execute(f"SET s3_access_key_id='{self.aws_access_key_id}'")
        self.duck_db_conn.execute(f"SET s3_secret_access_key='{self.aws_secret_access_key}'")
        if self.aws_session_token:
            # Temporary credentials (e.g. from STS) also require a session token
            self.duck_db_conn.execute(f"SET s3_session_token='{self.aws_session_token}'")

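    # Credential handling note: the SET statements in __init__ configure S3 access for the
    # entire DuckDB connection. Newer DuckDB releases also offer a secrets manager
    # (CREATE SECRET ... (TYPE S3, ...)) as an alternative way to register S3 credentials;
    # the session-level settings used above are the longer-standing interface.
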
    def list_tables(self) -> List[Dict[str, Any]]:
        # Use boto3 to list objects in the bucket
        import boto3

        s3_client = boto3.client(
            's3',
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token if self.aws_session_token else None,
            region_name=self.region_name
        )

        # List objects in the bucket (a single list_objects_v2 call returns at most 1000 keys)
        response = s3_client.list_objects_v2(Bucket=self.bucket)

        results = []

        if 'Contents' in response:
            for obj in response['Contents']:
                key = obj['Key']

                # Skip directories and non-data files
                if key.endswith('/') or not self._is_supported_file(key):
                    continue

                # Create S3 URL
                s3_url = f"s3://{self.bucket}/{key}"

                try:
                    # Choose the appropriate read function based on file extension
                    if s3_url.lower().endswith('.parquet'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_parquet('{s3_url}') LIMIT 10").df()
                    elif s3_url.lower().endswith('.json') or s3_url.lower().endswith('.jsonl'):
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_json_auto('{s3_url}') LIMIT 10").df()
                    else:
                        # Default to CSV for the remaining supported formats
                        sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT 10").df()

                    # Get column information
                    columns = [{
                        'name': col,
                        'type': str(sample_df[col].dtype)
                    } for col in sample_df.columns]

                    # Get sample data
                    sample_rows = json.loads(sample_df.to_json(orient="records"))

                    # Estimate row count (this is approximate for CSV files)
                    row_count = self._estimate_row_count(s3_url)

                    table_metadata = {
                        "row_count": row_count,
                        "columns": columns,
                        "sample_rows": sample_rows
                    }

                    results.append({
                        "name": s3_url,
                        "metadata": table_metadata
                    })
                except Exception as e:
                    # Skip files that can't be read
                    print(f"Error reading {s3_url}: {e}")
                    continue

        return results

    def _is_supported_file(self, key: str) -> bool:
        """Check if the file type is supported by DuckDB."""
        supported_extensions = ['.csv', '.parquet', '.json', '.jsonl']
        return any(key.lower().endswith(ext) for ext in supported_extensions)

    def _estimate_row_count(self, s3_url: str) -> int:
        """Estimate the number of rows in a file."""
        try:
            # For parquet files, we can get the exact count
            if s3_url.lower().endswith('.parquet'):
                count = self.duck_db_conn.execute(f"SELECT COUNT(*) FROM read_parquet('{s3_url}')").fetchone()[0]
                return count

            # For other formats, sample the file to estimate the average row size
            sample_size = 1000
            sample_df = self.duck_db_conn.execute(f"SELECT * FROM read_csv_auto('{s3_url}') LIMIT {sample_size}").df()

            # Get file size from S3
            import boto3
            s3_client = boto3.client(
                's3',
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                aws_session_token=self.aws_session_token if self.aws_session_token else None,
                region_name=self.region_name
            )

            key = s3_url.replace(f"s3://{self.bucket}/", "")
            response = s3_client.head_object(Bucket=self.bucket, Key=key)
            file_size = response['ContentLength']

            # Estimate total rows from the file size and the average serialized row size of the sample
            if len(sample_df) > 0:
                # Average row size in bytes, measured from the CSV-serialized sample
                sample_bytes = len(sample_df.to_csv(index=False, header=False).encode('utf-8'))
                avg_row_size = sample_bytes / len(sample_df)
                estimated_rows = int(file_size / avg_row_size)
                return min(estimated_rows, 1000000)  # Cap at 1 million for UI performance

            return 0
        except Exception as e:
            print(f"Error estimating row count for {s3_url}: {e}")
            return 0

    def ingest_data(self, table_name: str, name_as: Optional[str] = None, size: int = 1000000):
        if name_as is None:
            name_as = table_name.split('/')[-1].split('.')[0]

        name_as = sanitize_table_name(name_as)

        # Determine file type and use the appropriate DuckDB read function
        if table_name.lower().endswith('.csv'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_csv_auto('{table_name}')
                LIMIT {size}
            """)
        elif table_name.lower().endswith('.parquet'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_parquet('{table_name}')
                LIMIT {size}
            """)
        elif table_name.lower().endswith('.json') or table_name.lower().endswith('.jsonl'):
            self.duck_db_conn.execute(f"""
                CREATE OR REPLACE TABLE main.{name_as} AS
                SELECT * FROM read_json_auto('{table_name}')
                LIMIT {size}
            """)
        else:
            raise ValueError(f"Unsupported file type: {table_name}")

    def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
        return self.duck_db_conn.execute(query).df().head(10).to_dict(orient="records")

    def ingest_data_from_query(self, query: str, name_as: str):
        # Execute the query and get results as a DataFrame
        df = self.duck_db_conn.execute(query).df()
        # Use the base class's method to ingest the DataFrame
        self.ingest_df_to_duckdb(df, name_as)
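

# The sketch below is not part of the loader itself; it is a minimal usage example showing
# how the class is expected to be driven. The credentials, region, bucket name, and object
# key are placeholder assumptions, and the in-memory DuckDB connection stands in for
# whatever connection Data Formulator normally passes in.
if __name__ == "__main__":
    conn = duckdb.connect()  # in-memory DuckDB connection, just for the example
    loader = S3DataLoader(
        params={
            "aws_access_key_id": "YOUR_ACCESS_KEY_ID",          # placeholder
            "aws_secret_access_key": "YOUR_SECRET_ACCESS_KEY",  # placeholder
            "aws_session_token": "",                            # optional, for temporary credentials
            "region_name": "us-east-1",
            "bucket": "your-bucket-name",                       # placeholder
        },
        duck_db_conn=conn,
    )

    # Inspect which objects in the bucket can be loaded and their estimated sizes.
    for table in loader.list_tables():
        print(table["name"], table["metadata"]["row_count"])

    # Ingest a specific object into DuckDB as table main.my_table (hypothetical key).
    # loader.ingest_data("s3://your-bucket-name/path/to/data.csv", name_as="my_table")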