Skip to content

Commit f9617f0

Browse files
authored
Merge pull request #117 from awslabs/aurora-load-columns
Add columns parameters to Pandas.to_aurora() and Pandas.to_csv()
2 parents 60ee9ae + 2dcec02 commit f9617f0

File tree

7 files changed

+247
-73
lines changed

7 files changed

+247
-73
lines changed

awswrangler/aurora.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
1+
from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any, Optional
22
from logging import getLogger, Logger, INFO
33
import json
44
import warnings
@@ -137,6 +137,7 @@ def load_table(dataframe: pd.DataFrame,
137137
table_name: str,
138138
connection: Any,
139139
num_files: int,
140+
columns: Optional[List[str]] = None,
140141
mode: str = "append",
141142
preserve_index: bool = False,
142143
engine: str = "mysql",
@@ -152,6 +153,7 @@ def load_table(dataframe: pd.DataFrame,
152153
:param table_name: Aurora table name
153154
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
154155
:param num_files: Number of files to be loaded
156+
:param columns: List of columns to load
155157
:param mode: append or overwrite
156158
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
157159
:param engine: "mysql" or "postgres"
@@ -167,7 +169,8 @@ def load_table(dataframe: pd.DataFrame,
167169
connection=connection,
168170
mode=mode,
169171
preserve_index=preserve_index,
170-
region=region)
172+
region=region,
173+
columns=columns)
171174
elif "mysql" in engine.lower():
172175
Aurora.load_table_mysql(dataframe=dataframe,
173176
dataframe_type=dataframe_type,
@@ -177,7 +180,8 @@ def load_table(dataframe: pd.DataFrame,
177180
connection=connection,
178181
mode=mode,
179182
preserve_index=preserve_index,
180-
num_files=num_files)
183+
num_files=num_files,
184+
columns=columns)
181185
else:
182186
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
183187

@@ -190,7 +194,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
190194
connection: Any,
191195
mode: str = "append",
192196
preserve_index: bool = False,
193-
region: str = "us-east-1"):
197+
region: str = "us-east-1",
198+
columns: Optional[List[str]] = None):
194199
"""
195200
Load text/CSV files into a Aurora table using a manifest file.
196201
Creates the table if necessary.
@@ -204,6 +209,7 @@ def load_table_postgres(dataframe: pd.DataFrame,
204209
:param mode: append or overwrite
205210
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
206211
:param region: AWS S3 bucket region (Required only for postgres engine)
212+
:param columns: List of columns to load
207213
:return: None
208214
"""
209215
with connection.cursor() as cursor:
@@ -214,15 +220,17 @@ def load_table_postgres(dataframe: pd.DataFrame,
214220
schema_name=schema_name,
215221
table_name=table_name,
216222
preserve_index=preserve_index,
217-
engine="postgres")
223+
engine="postgres",
224+
columns=columns)
218225
connection.commit()
219226
logger.debug("CREATE TABLE committed.")
220227
for path in load_paths:
221228
sql = Aurora._get_load_sql(path=path,
222229
schema_name=schema_name,
223230
table_name=table_name,
224231
engine="postgres",
225-
region=region)
232+
region=region,
233+
columns=columns)
226234
Aurora._load_object_postgres_with_retry(connection=connection, sql=sql)
227235
logger.debug(f"Load committed for: {path}.")
228236

@@ -257,7 +265,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
257265
connection: Any,
258266
num_files: int,
259267
mode: str = "append",
260-
preserve_index: bool = False):
268+
preserve_index: bool = False,
269+
columns: Optional[List[str]] = None):
261270
"""
262271
Load text/CSV files into a Aurora table using a manifest file.
263272
Creates the table if necessary.
@@ -271,6 +280,7 @@ def load_table_mysql(dataframe: pd.DataFrame,
271280
:param num_files: Number of files to be loaded
272281
:param mode: append or overwrite
273282
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
283+
:param columns: List of columns to load
274284
:return: None
275285
"""
276286
with connection.cursor() as cursor:
@@ -281,11 +291,13 @@ def load_table_mysql(dataframe: pd.DataFrame,
281291
schema_name=schema_name,
282292
table_name=table_name,
283293
preserve_index=preserve_index,
284-
engine="mysql")
294+
engine="mysql",
295+
columns=columns)
285296
sql = Aurora._get_load_sql(path=manifest_path,
286297
schema_name=schema_name,
287298
table_name=table_name,
288-
engine="mysql")
299+
engine="mysql",
300+
columns=columns)
289301
logger.debug(sql)
290302
cursor.execute(sql)
291303
logger.debug(f"Load done for: {manifest_path}")
@@ -310,22 +322,40 @@ def _parse_path(path):
310322
return parts[0], parts[2]
311323

312324
@staticmethod
313-
def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str:
325+
def _get_load_sql(path: str,
326+
schema_name: str,
327+
table_name: str,
328+
engine: str,
329+
region: str = "us-east-1",
330+
columns: Optional[List[str]] = None) -> str:
314331
if "postgres" in engine.lower():
315332
bucket, key = Aurora._parse_path(path=path)
333+
if columns is None:
334+
cols_str: str = ""
335+
else:
336+
cols_str = ",".join(columns)
316337
sql: str = ("-- AWS DATA WRANGLER\n"
317338
"SELECT aws_s3.table_import_from_s3(\n"
318339
f"'{schema_name}.{table_name}',\n"
319-
"'',\n"
340+
f"'{cols_str}',\n"
320341
"'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
321342
f"'({bucket},{key},{region})')")
322343
elif "mysql" in engine.lower():
344+
if columns is None:
345+
cols_str = ""
346+
else:
347+
# building something like: (@col1,@col2) set col1=@col1,col2=@col2
348+
col_str = [f"@{x}" for x in columns]
349+
set_str = [f"{x}=@{x}" for x in columns]
350+
cols_str = f"({','.join(col_str)}) SET {','.join(set_str)}"
351+
logger.debug(f"cols_str: {cols_str}")
323352
sql = ("-- AWS DATA WRANGLER\n"
324353
f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
325354
"REPLACE\n"
326355
f"INTO TABLE {schema_name}.{table_name}\n"
327356
"FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
328-
"LINES TERMINATED BY '\\n'")
357+
"LINES TERMINATED BY '\\n'"
358+
f"{cols_str}")
329359
else:
330360
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
331361
return sql
@@ -337,7 +367,8 @@ def _create_table(cursor,
337367
schema_name,
338368
table_name,
339369
preserve_index=False,
340-
engine: str = "mysql"):
370+
engine: str = "mysql",
371+
columns: Optional[List[str]] = None):
341372
"""
342373
Creates Aurora table.
343374
@@ -348,6 +379,7 @@ def _create_table(cursor,
348379
:param table_name: Redshift table name
349380
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
350381
:param engine: "mysql" or "postgres"
382+
:param columns: List of columns to load
351383
:return: None
352384
"""
353385
sql: str = f"-- AWS DATA WRANGLER\n" \
@@ -364,7 +396,8 @@ def _create_table(cursor,
364396
schema = Aurora._get_schema(dataframe=dataframe,
365397
dataframe_type=dataframe_type,
366398
preserve_index=preserve_index,
367-
engine=engine)
399+
engine=engine,
400+
columns=columns)
368401
cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
369402
sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
370403
logger.debug(f"Create table query:\n{sql}")
@@ -374,7 +407,8 @@ def _create_table(cursor,
374407
def _get_schema(dataframe,
375408
dataframe_type: str,
376409
preserve_index: bool,
377-
engine: str = "mysql") -> List[Tuple[str, str]]:
410+
engine: str = "mysql",
411+
columns: Optional[List[str]] = None) -> List[Tuple[str, str]]:
378412
schema_built: List[Tuple[str, str]] = []
379413
if "postgres" in engine.lower():
380414
convert_func = data_types.pyarrow2postgres
@@ -386,8 +420,9 @@ def _get_schema(dataframe,
386420
pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
387421
dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
388422
for name, dtype in pyarrow_schema:
389-
aurora_type: str = convert_func(dtype)
390-
schema_built.append((name, aurora_type))
423+
if columns is None or name in columns:
424+
aurora_type: str = convert_func(dtype)
425+
schema_built.append((name, aurora_type))
391426
else:
392427
raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
393428
return schema_built

awswrangler/data_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,8 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
370370
:param indexes_position: "right" or "left"
371371
:return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
372372
"""
373-
cols = []
374-
cols_dtypes = {}
373+
cols: List[str] = []
374+
cols_dtypes: Dict[str, str] = {}
375375
if indexes_position not in ("right", "left"):
376376
raise ValueError(f"indexes_position must be \"right\" or \"left\"")
377377

@@ -384,10 +384,10 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
384384
cols.append(name)
385385

386386
# Filling cols_dtypes and indexes
387-
indexes = []
387+
indexes: List[str] = []
388388
for field in pa.Schema.from_pandas(df=dataframe[cols], preserve_index=preserve_index):
389389
name = str(field.name)
390-
dtype = field.type
390+
dtype = str(field.type)
391391
cols_dtypes[name] = dtype
392392
if name not in dataframe.columns:
393393
indexes.append(name)

0 commit comments

Comments (0)