Skip to content

Commit 65904ce

Browse files
committed
Fix Redshift data type conversions
1 parent 48c4dfc commit 65904ce

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

awswrangler/glue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def type_pandas2athena(dtype):
5656
return "double"
5757
elif dtype == "bool":
5858
return "boolean"
59-
elif dtype == "object" and isinstance(dtype, str):
59+
elif dtype == "object":
6060
return "string"
6161
elif dtype[:10] == "datetime64":
6262
return "timestamp"

awswrangler/redshift.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,9 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False):
341341
dtype = str(dataframe.index.dtype)
342342
redshift_type = Redshift._type_pandas2redshift(dtype)
343343
schema_built.append((name, redshift_type))
344-
for col in dataframe.columns:
345-
name = str(col)
346-
dtype = str(dataframe[name].dtype)
344+
for col, dtype in dataframe.dtypes:
347345
redshift_type = Redshift._type_pandas2redshift(dtype)
348-
schema_built.append((name, redshift_type))
346+
schema_built.append((col, redshift_type))
349347
elif dataframe_type == "spark":
350348
for name, dtype in dataframe.dtypes:
351349
redshift_type = Redshift._type_spark2redshift(dtype)
@@ -377,17 +375,17 @@ def _type_pandas2redshift(dtype):
377375
@staticmethod
378376
def _type_spark2redshift(dtype):
379377
dtype = dtype.lower()
380-
if dtype == "int":
381-
return "INTEGER"
382-
elif dtype == "long":
378+
if dtype in ["smallint", "int", "bigint"]:
383379
return "BIGINT"
384380
elif dtype == "float":
381+
return "FLOAT4"
382+
elif dtype == "double":
385383
return "FLOAT8"
386384
elif dtype == "bool":
387385
return "BOOLEAN"
386+
elif dtype == "timestamp":
387+
return "TIMESTAMP"
388388
elif dtype == "string":
389389
return "VARCHAR(256)"
390-
elif dtype[:10] == "datetime.datetime":
391-
return "TIMESTAMP"
392390
else:
393391
raise UnsupportedType("Unsupported Spark type: " + dtype)

awswrangler/spark.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pyspark.sql.functions import pandas_udf, PandasUDFType
66
from pyspark.sql.functions import floor, rand
7+
from pyspark.sql.types import TimestampType
78

89
from awswrangler.exceptions import MissingBatchDetected
910

@@ -16,9 +17,29 @@ class Spark:
1617
def __init__(self, session):
1718
self._session = session
1819

19-
def read_csv(self, path):
20+
def read_csv(self, **args):
2021
spark = self._session.spark_session
21-
return spark.read.csv(path=path, header=True)
22+
return spark.read.csv(**args)
23+
24+
@staticmethod
25+
def _extract_casts(dtypes):
26+
casts = {}
27+
for col, dtype in dtypes:
28+
if dtype in ["smallint", "int", "bigint"]:
29+
casts[col] = "Int64"
30+
elif dtype == "object":
31+
casts[col] = "str"
32+
logger.debug(f"casts: {casts}")
33+
return casts
34+
35+
@staticmethod
36+
def date2timestamp(dataframe):
37+
for col, dtype in dataframe.dtypes:
38+
if dtype == "date":
39+
dataframe = dataframe.withColumn(
40+
col, dataframe[col].cast(TimestampType()))
41+
logger.warning(f"Casting column {col} from date to timestamp!")
42+
return dataframe
2243

2344
def to_redshift(
2445
self,
@@ -57,6 +78,7 @@ def to_redshift(
5778
path += "/"
5879
self._session.s3.delete_objects(path=path)
5980
spark = self._session.spark_session
81+
dataframe = Spark.date2timestamp(dataframe)
6082
dataframe.cache()
6183
num_rows = dataframe.count()
6284
logger.info(f"Number of rows: {num_rows}")
@@ -72,6 +94,7 @@ def to_redshift(
7294
logger.debug(f"Number of partitions calculated: {num_partitions}")
7395
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
7496
session_primitives = self._session.primitives
97+
casts = Spark._extract_casts(dataframe.dtypes)
7598

7699
@pandas_udf(returnType="objects_paths string",
77100
functionType=PandasUDFType.GROUPED_MAP)
@@ -83,7 +106,7 @@ def write(pandas_dataframe):
83106
preserve_index=False,
84107
mode="append",
85108
procs_cpu_bound=1,
86-
)
109+
cast_columns=casts)
87110
return pandas.DataFrame.from_dict({"objects_paths": paths})
88111

89112
df_objects_paths = (dataframe.withColumn(

testing/test_awswrangler/test_redshift.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,22 @@ def redshift_parameters(cloudformation_outputs):
8686
def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name,
8787
mode, factor, diststyle, distkey, sortstyle,
8888
sortkey):
89+
if sample_name == "micro":
90+
dates = ["date"]
91+
if sample_name == "small":
92+
dates = ["date"]
93+
if sample_name == "nano":
94+
dates = ["date", "time"]
95+
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv",
96+
parse_dates=dates,
97+
infer_datetime_format=True)
8998
con = Redshift.generate_connection(
9099
database="test",
91100
host=redshift_parameters.get("RedshiftAddress"),
92101
port=redshift_parameters.get("RedshiftPort"),
93102
user="test",
94103
password=redshift_parameters.get("RedshiftPassword"),
95104
)
96-
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
97105
path = f"s3://{bucket}/redshift-load/"
98106
session.pandas.to_redshift(
99107
dataframe=dataframe,
@@ -110,11 +118,12 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name,
110118
preserve_index=False,
111119
)
112120
cursor = con.cursor()
113-
cursor.execute("SELECT COUNT(*) as counter from public.test")
114-
counter = cursor.fetchall()[0][0]
121+
cursor.execute("SELECT * from public.test")
122+
rows = cursor.fetchall()
115123
cursor.close()
116124
con.close()
117-
assert len(dataframe.index) * factor == counter
125+
assert len(dataframe.index) * factor == len(rows)
126+
assert len(list(dataframe.columns)) == len(list(rows[0]))
118127

119128

120129
@pytest.mark.parametrize(
@@ -135,14 +144,14 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name,
135144
def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters,
136145
sample_name, mode, factor, diststyle,
137146
distkey, sortstyle, sortkey, exc):
147+
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
138148
con = Redshift.generate_connection(
139149
database="test",
140150
host=redshift_parameters.get("RedshiftAddress"),
141151
port=redshift_parameters.get("RedshiftPort"),
142152
user="test",
143153
password=redshift_parameters.get("RedshiftPassword"),
144154
)
145-
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
146155
path = f"s3://{bucket}/redshift-load/"
147156
with pytest.raises(exc):
148157
assert session.pandas.to_redshift(
@@ -180,7 +189,20 @@ def test_to_redshift_spark(session, bucket, redshift_parameters, sample_name,
180189
mode, factor, diststyle, distkey, sortstyle,
181190
sortkey):
182191
path = f"data_samples/{sample_name}.csv"
183-
dataframe = session.spark.read_csv(path=path)
192+
if sample_name == "micro":
193+
schema = "id SMALLINT, name STRING, value FLOAT, date TIMESTAMP"
194+
timestamp_format = "yyyy-MM-dd"
195+
elif sample_name == "small":
196+
schema = "id BIGINT, name STRING, date DATE"
197+
timestamp_format = "dd-MM-yy"
198+
elif sample_name == "nano":
199+
schema = "id INTEGER, name STRING, value DOUBLE, date TIMESTAMP, time TIMESTAMP"
200+
timestamp_format = "yyyy-MM-dd"
201+
dataframe = session.spark.read_csv(path=path,
202+
schema=schema,
203+
timestampFormat=timestamp_format,
204+
dateFormat=timestamp_format,
205+
header=True)
184206
con = Redshift.generate_connection(
185207
database="test",
186208
host=redshift_parameters.get("RedshiftAddress"),
@@ -203,11 +225,12 @@ def test_to_redshift_spark(session, bucket, redshift_parameters, sample_name,
203225
min_num_partitions=2,
204226
)
205227
cursor = con.cursor()
206-
cursor.execute("SELECT COUNT(*) as counter from public.test")
207-
counter = cursor.fetchall()[0][0]
228+
cursor.execute("SELECT * from public.test")
229+
rows = cursor.fetchall()
208230
cursor.close()
209231
con.close()
210-
assert dataframe.count() * factor == counter
232+
assert (dataframe.count() * factor) == len(rows)
233+
assert len(list(dataframe.columns)) == len(list(rows[0]))
211234

212235

213236
@pytest.mark.parametrize(

0 commit comments

Comments (0)