Commit 2dcec02

Add columns parameters for Pandas.to_aurora() and Aurora.load_table()
1 parent 7184977 commit 2dcec02

File tree: 3 files changed (+151, -33 lines)

awswrangler/aurora.py

Lines changed: 52 additions & 17 deletions
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
+from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any, Optional
 from logging import getLogger, Logger, INFO
 import json
 import warnings
@@ -137,6 +137,7 @@ def load_table(dataframe: pd.DataFrame,
                    table_name: str,
                    connection: Any,
                    num_files: int,
+                   columns: Optional[List[str]] = None,
                    mode: str = "append",
                    preserve_index: bool = False,
                    engine: str = "mysql",
@@ -152,6 +153,7 @@ def load_table(dataframe: pd.DataFrame,
         :param table_name: Aurora table name
         :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
         :param num_files: Number of files to be loaded
+        :param columns: List of columns to load
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param engine: "mysql" or "postgres"
@@ -167,7 +169,8 @@ def load_table(dataframe: pd.DataFrame,
                                        connection=connection,
                                        mode=mode,
                                        preserve_index=preserve_index,
-                                       region=region)
+                                       region=region,
+                                       columns=columns)
         elif "mysql" in engine.lower():
             Aurora.load_table_mysql(dataframe=dataframe,
                                     dataframe_type=dataframe_type,
@@ -177,7 +180,8 @@ def load_table(dataframe: pd.DataFrame,
                                     connection=connection,
                                     mode=mode,
                                     preserve_index=preserve_index,
-                                    num_files=num_files)
+                                    num_files=num_files,
+                                    columns=columns)
         else:
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")

@@ -190,7 +194,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
                             connection: Any,
                             mode: str = "append",
                             preserve_index: bool = False,
-                            region: str = "us-east-1"):
+                            region: str = "us-east-1",
+                            columns: Optional[List[str]] = None):
         """
         Load text/CSV files into a Aurora table using a manifest file.
         Creates the table if necessary.
@@ -204,6 +209,7 @@ def load_table_postgres(dataframe: pd.DataFrame,
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param region: AWS S3 bucket region (Required only for postgres engine)
+        :param columns: List of columns to load
         :return: None
         """
         with connection.cursor() as cursor:
@@ -214,15 +220,17 @@ def load_table_postgres(dataframe: pd.DataFrame,
                                      schema_name=schema_name,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
-                                     engine="postgres")
+                                     engine="postgres",
+                                     columns=columns)
             connection.commit()
             logger.debug("CREATE TABLE committed.")
         for path in load_paths:
             sql = Aurora._get_load_sql(path=path,
                                        schema_name=schema_name,
                                        table_name=table_name,
                                        engine="postgres",
-                                       region=region)
+                                       region=region,
+                                       columns=columns)
             Aurora._load_object_postgres_with_retry(connection=connection, sql=sql)
             logger.debug(f"Load committed for: {path}.")

@@ -257,7 +265,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
                          connection: Any,
                          num_files: int,
                          mode: str = "append",
-                         preserve_index: bool = False):
+                         preserve_index: bool = False,
+                         columns: Optional[List[str]] = None):
         """
         Load text/CSV files into a Aurora table using a manifest file.
         Creates the table if necessary.
@@ -271,6 +280,7 @@ def load_table_mysql(dataframe: pd.DataFrame,
         :param num_files: Number of files to be loaded
         :param mode: append or overwrite
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
+        :param columns: List of columns to load
         :return: None
         """
         with connection.cursor() as cursor:
@@ -281,11 +291,13 @@ def load_table_mysql(dataframe: pd.DataFrame,
                                      schema_name=schema_name,
                                      table_name=table_name,
                                      preserve_index=preserve_index,
-                                     engine="mysql")
+                                     engine="mysql",
+                                     columns=columns)
             sql = Aurora._get_load_sql(path=manifest_path,
                                        schema_name=schema_name,
                                        table_name=table_name,
-                                       engine="mysql")
+                                       engine="mysql",
+                                       columns=columns)
             logger.debug(sql)
             cursor.execute(sql)
             logger.debug(f"Load done for: {manifest_path}")
@@ -310,22 +322,40 @@ def _parse_path(path):
         return parts[0], parts[2]

     @staticmethod
-    def _get_load_sql(path: str, schema_name: str, table_name: str, engine: str, region: str = "us-east-1") -> str:
+    def _get_load_sql(path: str,
+                      schema_name: str,
+                      table_name: str,
+                      engine: str,
+                      region: str = "us-east-1",
+                      columns: Optional[List[str]] = None) -> str:
         if "postgres" in engine.lower():
             bucket, key = Aurora._parse_path(path=path)
+            if columns is None:
+                cols_str: str = ""
+            else:
+                cols_str = ",".join(columns)
             sql: str = ("-- AWS DATA WRANGLER\n"
                         "SELECT aws_s3.table_import_from_s3(\n"
                         f"'{schema_name}.{table_name}',\n"
-                        "'',\n"
+                        f"'{cols_str}',\n"
                         "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                         f"'({bucket},{key},{region})')")
         elif "mysql" in engine.lower():
+            if columns is None:
+                cols_str = ""
+            else:
+                # building something like: (@col1,@col2) set col1=@col1,col2=@col2
+                col_str = [f"@{x}" for x in columns]
+                set_str = [f"{x}=@{x}" for x in columns]
+                cols_str = f"({','.join(col_str)}) SET {','.join(set_str)}"
+                logger.debug(f"cols_str: {cols_str}")
             sql = ("-- AWS DATA WRANGLER\n"
                    f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                    "REPLACE\n"
                    f"INTO TABLE {schema_name}.{table_name}\n"
                    "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
-                   "LINES TERMINATED BY '\\n'")
+                   "LINES TERMINATED BY '\\n'"
+                   f"{cols_str}")
         else:
             raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
         return sql
@@ -337,7 +367,8 @@ def _create_table(cursor,
                       schema_name,
                       table_name,
                       preserve_index=False,
-                      engine: str = "mysql"):
+                      engine: str = "mysql",
+                      columns: Optional[List[str]] = None):
         """
         Creates Aurora table.

@@ -348,6 +379,7 @@ def _create_table(cursor,
         :param table_name: Redshift table name
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param engine: "mysql" or "postgres"
+        :param columns: List of columns to load
         :return: None
         """
         sql: str = f"-- AWS DATA WRANGLER\n" \
@@ -364,7 +396,8 @@ def _create_table(cursor,
         schema = Aurora._get_schema(dataframe=dataframe,
                                     dataframe_type=dataframe_type,
                                     preserve_index=preserve_index,
-                                    engine=engine)
+                                    engine=engine,
+                                    columns=columns)
         cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
         sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
         logger.debug(f"Create table query:\n{sql}")
@@ -374,7 +407,8 @@
     def _get_schema(dataframe,
                     dataframe_type: str,
                     preserve_index: bool,
-                    engine: str = "mysql") -> List[Tuple[str, str]]:
+                    engine: str = "mysql",
+                    columns: Optional[List[str]] = None) -> List[Tuple[str, str]]:
         schema_built: List[Tuple[str, str]] = []
         if "postgres" in engine.lower():
             convert_func = data_types.pyarrow2postgres
@@ -386,8 +420,9 @@ def _get_schema(dataframe,
             pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
                 dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
             for name, dtype in pyarrow_schema:
-                aurora_type: str = convert_func(dtype)
-                schema_built.append((name, aurora_type))
+                if columns is None or name in columns:
+                    aurora_type: str = convert_func(dtype)
+                    schema_built.append((name, aurora_type))
         else:
             raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
         return schema_built
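
Note: the per-engine handling of the new columns argument in _get_load_sql is the heart of the change. Postgres' aws_s3.table_import_from_s3 takes the column list as a single comma-separated string (the empty string keeps the old load-everything behavior), while MySQL's LOAD DATA needs a user-variable mapping clause. Below is a minimal standalone sketch of the two string builders; the helper names are illustrative, not part of awswrangler:

from typing import List, Optional


def postgres_cols_arg(columns: Optional[List[str]] = None) -> str:
    # aws_s3.table_import_from_s3 receives the target columns as one
    # comma-separated string; '' means "all columns", as before.
    return "" if columns is None else ",".join(columns)


def mysql_cols_clause(columns: Optional[List[str]] = None) -> str:
    # LOAD DATA reads each CSV field into a user variable, then SET
    # assigns the variables to the chosen table columns.
    if columns is None:
        return ""
    col_str = [f"@{x}" for x in columns]
    set_str = [f"{x}=@{x}" for x in columns]
    return f"({','.join(col_str)}) SET {','.join(set_str)}"


print(postgres_cols_arg(["id", "value"]))  # id,value
print(mysql_cols_clause(["id", "value"]))  # (@id,@value) SET id=@id,value=@value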

awswrangler/pandas.py

Lines changed: 2 additions & 1 deletion
@@ -870,7 +870,7 @@ def to_s3(self,
             partition_cols = [Athena.normalize_column_name(x) for x in partition_cols]
         logger.debug(f"partition_cols: {partition_cols}")
         if extra_args is not None and "columns" in extra_args:
-            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]
+            extra_args["columns"] = [Athena.normalize_column_name(x) for x in extra_args["columns"]]  # type: ignore
         dataframe = Pandas.drop_duplicated_columns(dataframe=dataframe, inplace=inplace)
         if compression is not None:
             compression = compression.lower()
@@ -1691,6 +1691,7 @@ def to_aurora(self,
                           load_paths=load_paths,
                           schema_name=schema,
                           table_name=table,
+                          columns=columns,
                           connection=connection,
                           num_files=len(paths),
                           mode=mode,
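
With columns threaded from Pandas.to_aurora into Aurora.load_table, a subset load looks like this (a usage sketch mirroring the new tests below; the host, password, bucket, and table names are placeholders):

import pandas as pd
import awswrangler as wr
from awswrangler import Aurora

df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"], "value2": [4, 5, 6]})
conn = Aurora.generate_connection(database="mysql",
                                  host="my-cluster.cluster-xyz.us-east-1.rds.amazonaws.com",  # placeholder
                                  port=3306,
                                  user="test",
                                  password="my-password",  # placeholder
                                  engine="mysql")
# Only "id" and "value" are created and loaded; "value2" is left out.
wr.pandas.to_aurora(dataframe=df,
                    connection=conn,
                    schema="test",
                    table="my_table",  # placeholder
                    mode="overwrite",
                    temp_s3_path="s3://my-bucket/temp",  # placeholder
                    engine="mysql",
                    columns=["id", "value"])
conn.close()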

testing/test_awswrangler/test_pandas.py

Lines changed: 97 additions & 15 deletions
@@ -2206,24 +2206,106 @@ def test_range_index(bucket, database):
 def test_to_csv_columns(bucket, database):
     path = f"s3://{bucket}/test_to_csv_columns"
     wr.s3.delete_objects(path=path)
-    df = pd.DataFrame({
-        "A": [1, 2, 3],
-        "B": [4, 5, 6],
-        "C": ["foo", "boo", "bar"]
-    })
+    df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": ["foo", "boo", "bar"]})
     wr.s3.delete_objects(path=path)
-    wr.pandas.to_csv(
-        dataframe=df,
-        database=database,
-        path=path,
-        columns=["A", "B"],
-        mode="overwrite",
-        preserve_index=False,
-        procs_cpu_bound=1,
-        inplace=False
-    )
+    wr.pandas.to_csv(dataframe=df,
+                     database=database,
+                     path=path,
+                     columns=["A", "B"],
+                     mode="overwrite",
+                     preserve_index=False,
+                     procs_cpu_bound=1,
+                     inplace=False)
     sleep(10)
     df2 = wr.pandas.read_sql_athena(database=database, sql="SELECT * FROM test_to_csv_columns")
     wr.s3.delete_objects(path=path)
     assert len(df.columns) == len(df2.columns) + 1
     assert len(df.index) == len(df2.index)
+
+
+def test_aurora_postgres_load_columns(bucket, postgres_parameters):
+    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"], "value2": [4, 5, 6]})
+    conn = Aurora.generate_connection(database="postgres",
+                                      host=postgres_parameters["PostgresAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=postgres_parameters["Password"],
+                                      engine="postgres")
+    path = f"s3://{bucket}/test_aurora_postgres_load_columns"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_columns",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="postgres",
+                        columns=["id", "value"])
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="public",
+                        table="test_aurora_postgres_load_columns",
+                        mode="append",
+                        temp_s3_path=path,
+                        engine="postgres",
+                        columns=["value"])
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM public.test_aurora_postgres_load_columns")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index) * 2
+        assert rows[0][0] == 1
+        assert rows[1][0] == 2
+        assert rows[2][0] == 3
+        assert rows[3][0] is None
+        assert rows[4][0] is None
+        assert rows[5][0] is None
+        assert rows[0][1] == "foo"
+        assert rows[1][1] == "boo"
+        assert rows[2][1] == "bar"
+        assert rows[3][1] == "foo"
+        assert rows[4][1] == "boo"
+        assert rows[5][1] == "bar"
+    conn.close()
+
+
+def test_aurora_mysql_load_columns(bucket, mysql_parameters):
+    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "boo", "bar"], "value2": [4, 5, 6]})
+    conn = Aurora.generate_connection(database="mysql",
+                                      host=mysql_parameters["MysqlAddress"],
+                                      port=3306,
+                                      user="test",
+                                      password=mysql_parameters["Password"],
+                                      engine="mysql")
+    path = f"s3://{bucket}/test_aurora_mysql_load_columns"
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="test",
+                        table="test_aurora_mysql_load_columns",
+                        mode="overwrite",
+                        temp_s3_path=path,
+                        engine="mysql",
+                        columns=["id", "value"])
+    wr.pandas.to_aurora(dataframe=df,
+                        connection=conn,
+                        schema="test",
+                        table="test_aurora_mysql_load_columns",
+                        mode="append",
+                        temp_s3_path=path,
+                        engine="mysql",
+                        columns=["value"])
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM test.test_aurora_mysql_load_columns")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index) * 2
+        assert rows[0][0] == 1
+        assert rows[1][0] == 2
+        assert rows[2][0] == 3
+        assert rows[3][0] is None
+        assert rows[4][0] is None
+        assert rows[5][0] is None
+        assert rows[0][1] == "foo"
+        assert rows[1][1] == "boo"
+        assert rows[2][1] == "bar"
+        assert rows[3][1] == "foo"
+        assert rows[4][1] == "boo"
+        assert rows[5][1] == "bar"
+    conn.close()
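
Both tests assert the same behavior: the overwrite load creates the table with only id and value, and the append load maps the CSV onto value alone, so the unloaded id column falls back to NULL for the appended rows. The result set the asserts expect looks like this (illustrative; the tests rely on the row order the plain SELECT happens to return):

# Rows fetched by the final SELECT in either test:
expected = [
    (1, "foo"), (2, "boo"), (3, "bar"),           # overwrite pass: id and value loaded
    (None, "foo"), (None, "boo"), (None, "bar"),  # append pass: only value loaded, id is NULL
]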
