1- from typing import TYPE_CHECKING , Union , List , Dict , Tuple , Any
1+ from typing import TYPE_CHECKING , Union , List , Dict , Tuple , Any , Optional
22from logging import getLogger , Logger , INFO
33import json
44import warnings
@@ -137,6 +137,7 @@ def load_table(dataframe: pd.DataFrame,
137137 table_name : str ,
138138 connection : Any ,
139139 num_files : int ,
140+ columns : Optional [List [str ]] = None ,
140141 mode : str = "append" ,
141142 preserve_index : bool = False ,
142143 engine : str = "mysql" ,
@@ -152,6 +153,7 @@ def load_table(dataframe: pd.DataFrame,
152153 :param table_name: Aurora table name
153154 :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
154155 :param num_files: Number of files to be loaded
156+ :param columns: List of columns to load
155157 :param mode: append or overwrite
156158 :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
157159 :param engine: "mysql" or "postgres"
@@ -167,7 +169,8 @@ def load_table(dataframe: pd.DataFrame,
167169 connection = connection ,
168170 mode = mode ,
169171 preserve_index = preserve_index ,
170- region = region )
172+ region = region ,
173+ columns = columns )
171174 elif "mysql" in engine .lower ():
172175 Aurora .load_table_mysql (dataframe = dataframe ,
173176 dataframe_type = dataframe_type ,
@@ -177,7 +180,8 @@ def load_table(dataframe: pd.DataFrame,
177180 connection = connection ,
178181 mode = mode ,
179182 preserve_index = preserve_index ,
180- num_files = num_files )
183+ num_files = num_files ,
184+ columns = columns )
181185 else :
182186 raise InvalidEngine (f"{ engine } is not a valid engine. Please use 'mysql' or 'postgres'!" )
183187
@@ -190,7 +194,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
190194 connection : Any ,
191195 mode : str = "append" ,
192196 preserve_index : bool = False ,
193- region : str = "us-east-1" ):
197+ region : str = "us-east-1" ,
198+ columns : Optional [List [str ]] = None ):
194199 """
195200 Load text/CSV files into a Aurora table using a manifest file.
196201 Creates the table if necessary.
@@ -204,6 +209,7 @@ def load_table_postgres(dataframe: pd.DataFrame,
204209 :param mode: append or overwrite
205210 :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
206211 :param region: AWS S3 bucket region (Required only for postgres engine)
212+ :param columns: List of columns to load
207213 :return: None
208214 """
209215 with connection .cursor () as cursor :
@@ -214,15 +220,17 @@ def load_table_postgres(dataframe: pd.DataFrame,
214220 schema_name = schema_name ,
215221 table_name = table_name ,
216222 preserve_index = preserve_index ,
217- engine = "postgres" )
223+ engine = "postgres" ,
224+ columns = columns )
218225 connection .commit ()
219226 logger .debug ("CREATE TABLE committed." )
220227 for path in load_paths :
221228 sql = Aurora ._get_load_sql (path = path ,
222229 schema_name = schema_name ,
223230 table_name = table_name ,
224231 engine = "postgres" ,
225- region = region )
232+ region = region ,
233+ columns = columns )
226234 Aurora ._load_object_postgres_with_retry (connection = connection , sql = sql )
227235 logger .debug (f"Load committed for: { path } ." )
228236
@@ -257,7 +265,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
257265 connection : Any ,
258266 num_files : int ,
259267 mode : str = "append" ,
260- preserve_index : bool = False ):
268+ preserve_index : bool = False ,
269+ columns : Optional [List [str ]] = None ):
261270 """
262271 Load text/CSV files into a Aurora table using a manifest file.
263272 Creates the table if necessary.
@@ -271,6 +280,7 @@ def load_table_mysql(dataframe: pd.DataFrame,
271280 :param num_files: Number of files to be loaded
272281 :param mode: append or overwrite
273282 :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
283+ :param columns: List of columns to load
274284 :return: None
275285 """
276286 with connection .cursor () as cursor :
@@ -281,11 +291,13 @@ def load_table_mysql(dataframe: pd.DataFrame,
281291 schema_name = schema_name ,
282292 table_name = table_name ,
283293 preserve_index = preserve_index ,
284- engine = "mysql" )
294+ engine = "mysql" ,
295+ columns = columns )
285296 sql = Aurora ._get_load_sql (path = manifest_path ,
286297 schema_name = schema_name ,
287298 table_name = table_name ,
288- engine = "mysql" )
299+ engine = "mysql" ,
300+ columns = columns )
289301 logger .debug (sql )
290302 cursor .execute (sql )
291303 logger .debug (f"Load done for: { manifest_path } " )
@@ -310,22 +322,40 @@ def _parse_path(path):
310322 return parts [0 ], parts [2 ]
311323
@staticmethod
def _get_load_sql(path: str,
                  schema_name: str,
                  table_name: str,
                  engine: str,
                  region: str = "us-east-1",
                  columns: Optional[List[str]] = None) -> str:
    """
    Build the engine-specific SQL statement that loads S3 data into an Aurora table.

    :param path: S3 path (for postgres: a single CSV object; for mysql: a manifest file)
    :param schema_name: Aurora schema name
    :param table_name: Aurora table name
    :param engine: "mysql" or "postgres"
    :param region: AWS S3 bucket region (used only by the postgres engine)
    :param columns: List of columns to load (None loads all columns)
    :return: The SQL statement, ready to be executed
    :raises InvalidEngine: If engine matches neither "mysql" nor "postgres"
    """
    # NOTE(review): schema/table/column names are interpolated into the SQL text
    # without escaping — callers must pass trusted identifiers only.
    if "postgres" in engine.lower():
        bucket, key = Aurora._parse_path(path=path)
        # aws_s3.table_import_from_s3 takes the column list as one
        # comma-separated string; an empty string means "all columns".
        if columns is None:
            cols_str: str = ""
        else:
            cols_str = ",".join(columns)
        sql: str = ("-- AWS DATA WRANGLER\n"
                    "SELECT aws_s3.table_import_from_s3(\n"
                    f"'{schema_name}.{table_name}',\n"
                    f"'{cols_str}',\n"
                    "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                    f"'({bucket},{key},{region})')")
    elif "mysql" in engine.lower():
        if columns is None:
            cols_str = ""
        else:
            # Builds something like: (@col1,@col2) SET col1=@col1,col2=@col2
            col_str = [f"@{x}" for x in columns]
            set_str = [f"{x}=@{x}" for x in columns]
            # Fix: lead with a newline so the clause is cleanly separated from
            # the preceding LINES TERMINATED BY token instead of glued to it.
            cols_str = f"\n({','.join(col_str)}) SET {','.join(set_str)}"
        logger.debug("cols_str: %s", cols_str)  # lazy %-formatting for logging
        sql = ("-- AWS DATA WRANGLER\n"
               f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
               "REPLACE\n"
               f"INTO TABLE {schema_name}.{table_name}\n"
               "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
               "LINES TERMINATED BY '\\n'"
               f"{cols_str}")
    else:
        raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
    return sql
@@ -337,7 +367,8 @@ def _create_table(cursor,
337367 schema_name ,
338368 table_name ,
339369 preserve_index = False ,
340- engine : str = "mysql" ):
370+ engine : str = "mysql" ,
371+ columns : Optional [List [str ]] = None ):
341372 """
342373 Creates Aurora table.
343374
@@ -348,6 +379,7 @@ def _create_table(cursor,
348379 :param table_name: Redshift table name
349380 :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
350381 :param engine: "mysql" or "postgres"
382+ :param columns: List of columns to load
351383 :return: None
352384 """
353385 sql : str = f"-- AWS DATA WRANGLER\n " \
@@ -364,7 +396,8 @@ def _create_table(cursor,
364396 schema = Aurora ._get_schema (dataframe = dataframe ,
365397 dataframe_type = dataframe_type ,
366398 preserve_index = preserve_index ,
367- engine = engine )
399+ engine = engine ,
400+ columns = columns )
368401 cols_str : str = "" .join ([f"{ col [0 ]} { col [1 ]} ,\n " for col in schema ])[:- 2 ]
369402 sql = f"-- AWS DATA WRANGLER\n " f"CREATE TABLE IF NOT EXISTS { schema_name } .{ table_name } (\n " f"{ cols_str } )"
370403 logger .debug (f"Create table query:\n { sql } " )
@@ -374,7 +407,8 @@ def _create_table(cursor,
374407 def _get_schema (dataframe ,
375408 dataframe_type : str ,
376409 preserve_index : bool ,
377- engine : str = "mysql" ) -> List [Tuple [str , str ]]:
410+ engine : str = "mysql" ,
411+ columns : Optional [List [str ]] = None ) -> List [Tuple [str , str ]]:
378412 schema_built : List [Tuple [str , str ]] = []
379413 if "postgres" in engine .lower ():
380414 convert_func = data_types .pyarrow2postgres
@@ -386,8 +420,9 @@ def _get_schema(dataframe,
386420 pyarrow_schema : List [Tuple [str , str ]] = data_types .extract_pyarrow_schema_from_pandas (
387421 dataframe = dataframe , preserve_index = preserve_index , indexes_position = "right" )
388422 for name , dtype in pyarrow_schema :
389- aurora_type : str = convert_func (dtype )
390- schema_built .append ((name , aurora_type ))
423+ if columns is None or name in columns :
424+ aurora_type : str = convert_func (dtype )
425+ schema_built .append ((name , aurora_type ))
391426 else :
392427 raise InvalidDataframeType (f"{ dataframe_type } is not a valid DataFrame type. Please use 'pandas'!" )
393428 return schema_built
0 commit comments