
Commit 632643a

Merge pull request #49 from awslabs/redshift-cast-columns

Fix pandas to redshift cast feature

2 parents 98a5e0d + fccfdf7
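What the fix amounts to: `to_redshift` takes `cast_columns` as a `Dict[str, str]` of column name to Redshift type, but the old code handed that dict directly to `data_types.convert_schema()`, whose `schema` parameter (per the new annotation in the diff below) expects a list of `(column, type)` tuples. The change converts the dict before the call and adds type hints throughout. A minimal, self-contained sketch of the repaired conversion path; `redshift2athena_stub` is a hypothetical stand-in for `awswrangler.data_types.redshift2athena`, whose real mapping lives in `awswrangler/data_types.py`:

```python
from typing import Dict, List, Tuple


def redshift2athena_stub(redshift_type: str) -> str:
    # Hypothetical stand-in for awswrangler.data_types.redshift2athena;
    # only a handful of types are mapped here for illustration.
    return {"BIGINT": "bigint", "VARCHAR": "string", "REAL": "float",
            "DATE": "date", "TIMESTAMP": "timestamp"}[redshift_type.upper()]


def convert_cast_columns(cast_columns: Dict[str, str]) -> Dict[str, str]:
    # Same shape as the fixed call site: dict -> list of (column, type)
    # tuples -> converted schema keyed by column name.
    cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
    return {col: redshift2athena_stub(typ) for col, typ in cast_columns_tuples}


print(convert_cast_columns({"id": "BIGINT", "bar": "TIMESTAMP"}))
# {'id': 'bigint', 'bar': 'timestamp'}
```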

2 files changed: +73 −29 lines changed

awswrangler/pandas.py

Lines changed: 29 additions & 26 deletions
```diff
@@ -1,3 +1,4 @@
+from typing import Dict, List, Tuple, Optional, Any
 from io import BytesIO, StringIO
 import multiprocessing as mp
 import logging
```
```diff
@@ -854,20 +855,20 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression, fs, ca

     def to_redshift(
             self,
-            dataframe,
-            path,
-            connection,
-            schema,
-            table,
-            iam_role,
-            diststyle="AUTO",
-            distkey=None,
-            sortstyle="COMPOUND",
-            sortkey=None,
-            preserve_index=False,
-            mode="append",
-            cast_columns=None,
-    ):
+            dataframe: pd.DataFrame,
+            path: str,
+            connection: Any,
+            schema: str,
+            table: str,
+            iam_role: str,
+            diststyle: str = "AUTO",
+            distkey: Optional[str] = None,
+            sortstyle: str = "COMPOUND",
+            sortkey: Optional[str] = None,
+            preserve_index: bool = False,
+            mode: str = "append",
+            cast_columns: Optional[Dict[str, str]] = None,
+    ) -> None:
         """
         Load Pandas Dataframe as a Table on Amazon Redshift

```
```diff
@@ -888,28 +889,30 @@ def to_redshift(
         """
         if cast_columns is None:
             cast_columns = {}
-            cast_columns_parquet = {}
+            cast_columns_parquet: Dict = {}
         else:
-            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena, schema=cast_columns)
+            cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
+            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena,
+                                                             schema=cast_columns_tuples)
         if path[-1] != "/":
             path += "/"
         self._session.s3.delete_objects(path=path)
-        num_rows = len(dataframe.index)
+        num_rows: int = len(dataframe.index)
         logger.debug(f"Number of rows: {num_rows}")
         if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
+            num_partitions: int = 1
         else:
-            num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
             logger.debug(f"Number of slices on Redshift: {num_slices}")
             num_partitions = num_slices
         logger.debug(f"Number of partitions calculated: {num_partitions}")
-        objects_paths = self.to_parquet(dataframe=dataframe,
-                                        path=path,
-                                        preserve_index=preserve_index,
-                                        mode="append",
-                                        procs_cpu_bound=num_partitions,
-                                        cast_columns=cast_columns_parquet)
-        manifest_path = f"{path}manifest.json"
+        objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
+                                                   path=path,
+                                                   preserve_index=preserve_index,
+                                                   mode="append",
+                                                   procs_cpu_bound=num_partitions,
+                                                   cast_columns=cast_columns_parquet)
+        manifest_path: str = f"{path}manifest.json"
         self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
         self._session.redshift.load_table(
             dataframe=dataframe,
```
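A note on the partition count above: when the frame is large enough, the number of Parquet files written is matched to the cluster's slice count, since a Redshift `COPY` parallelizes best when the number of input files is a multiple of the number of slices. A hedged sketch of what `get_number_of_slices` might do; `STV_SLICES` is a documented Redshift system view, but the exact query awswrangler runs is an assumption:

```python
def get_number_of_slices(redshift_conn) -> int:
    # Count the distinct slices in the cluster (assumed query, not
    # necessarily what awswrangler's helper actually executes).
    cursor = redshift_conn.cursor()
    cursor.execute("SELECT COUNT(DISTINCT slice) FROM stv_slices")
    num_slices: int = cursor.fetchone()[0]
    cursor.close()
    return num_slices
```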

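The `manifest.json` written by `write_load_manifest` points the subsequent `COPY` at the Parquet files just produced. The JSON shape below is the documented Amazon Redshift `COPY` manifest format; the helper's internals are a guess, and the paths are hypothetical:

```python
import json
from typing import List


def build_copy_manifest(objects_paths: List[str]) -> str:
    # One entry per Parquet file; "mandatory": true makes COPY fail loudly
    # if a listed object is missing instead of silently skipping it.
    entries = [{"url": path, "mandatory": True} for path in objects_paths]
    return json.dumps({"entries": entries}, indent=4)


print(build_copy_manifest([
    "s3://my-bucket/redshift-load/file0.parquet",  # hypothetical paths
    "s3://my-bucket/redshift-load/file1.parquet",
]))
```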
testing/test_awswrangler/test_redshift.py

Lines changed: 44 additions & 3 deletions
```diff
@@ -1,9 +1,10 @@
 import json
 import logging
+from datetime import date, datetime

 import pytest
 import boto3
-import pandas
+import pandas as pd
 from pyspark.sql import SparkSession
 import pg8000

```

```diff
@@ -80,7 +81,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     dates = ["date"]
     if sample_name == "nano":
         dates = ["date", "time"]
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
     dataframe["date"] = dataframe["date"].dt.date
     con = Redshift.generate_connection(
         database="test",
```
```diff
@@ -113,6 +114,46 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     assert len(list(dataframe.columns)) + 1 == len(list(rows[0]))


+def test_to_redshift_pandas_cast(session, bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3],
+        "name": ["name1", "name2", "name3"],
+        "foo": [None, None, None],
+        "boo": [date(2020, 1, 1), None, None],
+        "bar": [datetime(2021, 1, 1), None, None]})
+    schema = {
+        "id": "BIGINT",
+        "name": "VARCHAR",
+        "foo": "REAL",
+        "boo": "DATE",
+        "bar": "TIMESTAMP"}
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/redshift-load/"
+    session.pandas.to_redshift(dataframe=df,
+                               path=path,
+                               schema="public",
+                               table="test",
+                               connection=con,
+                               iam_role=redshift_parameters.get("RedshiftRole"),
+                               mode="overwrite",
+                               preserve_index=False,
+                               cast_columns=schema)
+    cursor = con.cursor()
+    cursor.execute("SELECT * from public.test")
+    rows = cursor.fetchall()
+    cursor.close()
+    con.close()
+    print(rows)
+    assert len(df.index) == len(rows)
+    assert len(list(df.columns)) == len(list(rows[0]))
+
+
 @pytest.mark.parametrize(
     "sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
     [
```
```diff
@@ -125,7 +166,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
 )
 def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle,
                                        distkey, sortstyle, sortkey, exc):
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv")
     con = Redshift.generate_connection(
         database="test",
         host=redshift_parameters.get("RedshiftAddress"),
```
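The new test asserts only row and column counts. The casts themselves could additionally be checked against the catalog; a hedged follow-up (not part of this PR) using Redshift's documented `PG_TABLE_DEF` view, which covers `public` under the default `search_path`:

```python
def fetch_column_types(con) -> dict:
    # Read back the column types Redshift recorded for public.test; with the
    # cast_columns above, "boo" should come back as a date type and "bar" as
    # a timestamp type even though both columns are mostly None in the frame.
    cursor = con.cursor()
    cursor.execute('SELECT "column", type FROM pg_table_def '
                   "WHERE schemaname = 'public' AND tablename = 'test'")
    types = {column: col_type for column, col_type in cursor.fetchall()}
    cursor.close()
    return types
```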
