Commit 3161a50

Now Pandas.read_parquet() will return Int64 for integers with null values, and Pandas.to_redshift() will also be able to cast them.

1 parent 1c28de7 commit 3161a50
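
For context, a minimal round-trip sketch of the behavior the commit message describes. The bucket path is hypothetical; the calls and the expected `Int64` dtype mirror the tests added in this commit.

```python
# Round-trip sketch: a nullable-integer column should survive write/read intact.
import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"col": [1, pd.NA, 3]}, dtype="Int64")  # nullable integers

# "s3://my-bucket/demo/" is a placeholder path.
wr.pandas.to_parquet(dataframe=df, path="s3://my-bucket/demo/",
                     preserve_index=False, mode="overwrite")
df2 = wr.pandas.read_parquet(path="s3://my-bucket/demo/")

assert str(df2.dtypes["col"]) == "Int64"  # nulls no longer degrade the dtype
```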

File tree

5 files changed (+64 lines, -6 lines):

- awswrangler/data_types.py
- awswrangler/pandas.py
- awswrangler/redshift.py
- testing/test_awswrangler/test_pandas.py
- testing/test_awswrangler/test_redshift.py

awswrangler/data_types.py (8 additions, 3 deletions)

```diff
@@ -382,24 +382,29 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:
 
 def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
                                        preserve_index: bool,
-                                       indexes_position: str = "right") -> List[Tuple[str, Any]]:
+                                       indexes_position: str = "right",
+                                       ignore_cols: Optional[List[str]] = None) -> List[Tuple[str, Any]]:
     """
     Extract the related Pyarrow schema from any Pandas DataFrame.
 
     :param dataframe: Pandas Dataframe
     :param preserve_index: True or False
     :param indexes_position: "right" or "left"
+    :param ignore_cols: List of columns to be ignored
     :return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
     """
+    ignore_cols = [] if ignore_cols is None else ignore_cols
     cols: List[str] = []
-    cols_dtypes: Dict[str, str] = {}
+    cols_dtypes: Dict[str, Optional[str]] = {}
     if indexes_position not in ("right", "left"):
         raise ValueError(f"indexes_position must be \"right\" or \"left\"")
 
     # Handle exception data types (e.g. Int64, string)
     for name, dtype in dataframe.dtypes.to_dict().items():
         dtype = str(dtype)
-        if dtype == "Int64":
+        if name in ignore_cols:
+            cols_dtypes[name] = None
+        elif dtype == "Int64":
             cols_dtypes[name] = "int64"
         elif dtype == "string":
             cols_dtypes[name] = "string"
```
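
To make the new parameter concrete, a small sketch that runs locally with no AWS access (the call matches the signature in the diff; per the diff, ignored columns are recorded with a `None` dtype rather than inferred, though the exact shape of the returned schema beyond that is not shown here):

```python
# Sketch of the new ignore_cols parameter on the internal data_types module.
import pandas as pd
from awswrangler import data_types

df = pd.DataFrame({
    "id": [1, 2, 3],
    "col4": pd.array([pd.NA, pd.NA, pd.NA], dtype="Int64"),  # all-null Int64
})

schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=df,
                                                       preserve_index=False,
                                                       indexes_position="right",
                                                       ignore_cols=["col4"])
# "col4" skips dtype inference entirely, so an all-null column can no longer
# break schema extraction; the caller supplies its type instead (see below).
```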

awswrangler/pandas.py (1 addition, 2 deletions)

```diff
@@ -843,7 +843,6 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
                                                            isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
-            dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
             for keys, subgroup in dataframe.groupby(by=partition_cols, observed=True):
                 subgroup = subgroup.drop(partition_cols, axis="columns")
                 if not isinstance(keys, tuple):
@@ -1390,7 +1389,7 @@ def _read_parquet_path(session_primitives: "SessionPrimitives",
             if str(field.type).startswith("int") and field.name != "__index_level_0__"
         ]
         logger.debug(f"Converting to Pandas: {path}")
-        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
+        df = table.to_pandas(use_threads=use_threads, integer_object_nulls=False)
         logger.debug(f"Casting Int64 columns: {path}")
         for c in integers:
             if not str(df[c].dtype).startswith("int"):
```
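
The flag flip matters because with `integer_object_nulls=True` pyarrow materializes integer columns containing nulls as object arrays of Python ints, while `False` yields plain float64 with NaN, which the loop above then casts back to pandas' nullable Int64. A standalone sketch of that two-step conversion, independent of awswrangler:

```python
# Two-step conversion strategy: cheap float64 from pyarrow, then cast to Int64.
import pandas as pd
import pyarrow as pa

table = pa.table({"col": pa.array([1, None, 3], type=pa.int64())})

# integer_object_nulls=False: nullable ints come back as float64 with NaN,
# far cheaper than object arrays of boxed Python ints.
df = table.to_pandas(integer_object_nulls=False)
assert str(df["col"].dtype) == "float64"

# Cast back to the pandas nullable integer dtype, as _read_parquet_path does
# for every integer field it collected from the parquet schema.
df["col"] = df["col"].astype("Int64")
assert df["col"].isna().sum() == 1
```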

awswrangler/redshift.py (3 additions, 1 deletion)

```diff
@@ -431,9 +431,11 @@ def _get_redshift_schema(dataframe,
     varchar_lengths = {} if varchar_lengths is None else varchar_lengths
     schema_built: List[Tuple[str, str]] = []
     if dataframe_type.lower() == "pandas":
+        ignore_cols = list(cast_columns.keys()) if cast_columns is not None else None
         pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                        preserve_index=preserve_index,
-                                                                       indexes_position="right")
+                                                                       indexes_position="right",
+                                                                       ignore_cols=ignore_cols)
         for name, dtype in pyarrow_schema:
             if (cast_columns is not None) and (name in cast_columns.keys()):
                 schema_built.append((name, cast_columns[name]))
```
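
Read together with the data_types.py change, the precedence rule is: a column named in `cast_columns` never goes through pyarrow inference, and its Redshift type comes straight from the user. Below is a hypothetical, simplified mirror of that loop (the real `_get_redshift_schema` also maps inferred pyarrow types to Redshift types, which is omitted here):

```python
# Simplified sketch of the precedence rule: explicit casts win over inference.
from typing import Dict, List, Optional, Tuple

def build_schema(pyarrow_schema: List[Tuple[str, Optional[str]]],
                 cast_columns: Optional[Dict[str, str]]) -> List[Tuple[str, str]]:
    schema_built: List[Tuple[str, str]] = []
    for name, dtype in pyarrow_schema:
        if cast_columns is not None and name in cast_columns:
            schema_built.append((name, cast_columns[name]))  # user cast wins
        else:
            # dtype is None only for ignored columns, which by construction are
            # exactly the cast_columns keys and were handled above.
            schema_built.append((name, dtype))
    return schema_built

# e.g. an all-null "col4" that inference skipped (None) but the user cast to INT8:
print(build_schema([("id", "int64"), ("col4", None)], {"col4": "INT8"}))
# -> [('id', 'int64'), ('col4', 'INT8')]
```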

testing/test_awswrangler/test_pandas.py (16 additions, 0 deletions)

```diff
@@ -2536,3 +2536,19 @@ def test_sequential_overwrite(bucket):
     df3 = wr.pandas.read_parquet(path=path)
     assert len(df3.index) == 1
     assert df3.col[0] == 2
+
+
+def test_read_parquet_int_na(bucket):
+    path = f"s3://{bucket}/test_read_parquet_int_na/"
+    df = pd.DataFrame({"col": [1] + [pd.NA for _ in range(10_000)]}, dtype="Int64")
+    wr.pandas.to_parquet(
+        dataframe=df,
+        path=path,
+        preserve_index=False,
+        mode="overwrite",
+        procs_cpu_bound=4
+    )
+    df2 = wr.pandas.read_parquet(path=path)
+    assert len(df2.index) == 10_001
+    assert len(df2.columns) == 1
+    assert df2.dtypes["col"] == "Int64"
```

testing/test_awswrangler/test_redshift.py (36 additions, 0 deletions)

```diff
@@ -911,3 +911,39 @@ def test_to_redshift_spark_varchar(session, bucket, redshift_parameters):
     for row in rows:
         assert len(row) == len(pdf.columns)
     conn.close()
+
+
+def test_to_redshift_int_na(bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3, 4, 5],
+        "col1": [1, pd.NA, 2, pd.NA, pd.NA],
+        "col2": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA],
+        "col3": [None, None, None, None, None],
+        "col4": [1, pd.NA, 2, pd.NA, pd.NA]
+    })
+    df["col1"] = df["col1"].astype("Int64")
+    df["col2"] = df["col2"].astype("Int64")
+    df["col3"] = df["col3"].astype("Int64")
+    path = f"s3://{bucket}/test_to_redshift_int_na"
+    wr.pandas.to_redshift(dataframe=df,
+                          path=path,
+                          schema="public",
+                          table="test_to_redshift_int_na",
+                          connection="aws-data-wrangler-redshift",
+                          iam_role=redshift_parameters.get("RedshiftRole"),
+                          mode="overwrite",
+                          preserve_index=False,
+                          cast_columns={
+                              "col4": "INT8"
+                          })
+    conn = wr.glue.get_connection("aws-data-wrangler-redshift")
+    with conn.cursor() as cursor:
+        cursor.execute("SELECT * FROM public.test_to_redshift_int_na")
+        rows = cursor.fetchall()
+        assert len(rows) == len(df.index)
+        for row in rows:
+            assert len(row) == len(df.columns)
+        cursor.execute("SELECT SUM(col1) FROM public.test_to_redshift_int_na")
+        rows = cursor.fetchall()
+        assert rows[0][0] == 3
+    conn.close()
```
