
Commit 95e37bf

Robert Schmidtke, kukushking, and Leon Luttenberger authored
fix: Index columns removed on s3.to_parquet (#2655)
fix: Index columns removed on s3.to_parquet (#2655)

* first go at a failing test
* pass missing dataset flag in test
* because we partition, do not specify full parquet paths during write
* use proper path in tests
* use reset_index to allow dropping the entire index
* test partitioning on full and partial index
* need to validate schema on read for issue to surface
* need to sort on index
* cross-test without partitioning
* print assertion error for remote debugging
* simplify test to just assert schema validation
* consistently handle regular and index columns casts
* use equality assertion utility, drop unnecessary sort
* add index partition test
* reformat
* undo categorical-specific dataframe creation in test
* try again to expect the right dtypes
* pull out to_parquet kwargs
* expect test to fail when using modin and partitioning on full index
* manually assert unpartitioned index is still present, then reset full index
* handle change in promotion kwargs for pyarrow 14+
* move packaging import to correct location
* fix types for promotion kwargs
* test and handle unnamed index levels as well

---------

Co-authored-by: Robert Schmidtke <[email protected]>
Co-authored-by: kukushking <[email protected]>
Co-authored-by: Leon Luttenberger <[email protected]>
1 parent: d6caa93 · commit: 95e37bf
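To make the fix concrete, here is a repro sketch of the reported behavior, inferred from the tests added in this commit (the bucket, database, and table names are placeholders): partitioning a dataset on its index levels dropped those columns from the written schema, so a schema-validating read could fail.

```python
# Hypothetical repro sketch (names are placeholders; failure mode inferred
# from the tests added in this commit).
import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"c0": [0, 1], "c1": [2, 3], "c2": [4, 5]}, dtype="Int64")
df = df.set_index(["c0", "c1"])

wr.s3.to_parquet(
    df,
    "s3://my-bucket/my-dataset/",
    index=True,
    dataset=True,
    partition_cols=["c0", "c1"],
    database="my_database",
    table="my_table",
)

# Before this fix, the index levels were dropped on write, so a validated
# read no longer lined up with the expected schema.
df2 = wr.s3.read_parquet("s3://my-bucket/my-dataset/", validate_schema=True)
```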

4 files changed: +83 −6 lines

awswrangler/_data_types.py — 18 additions & 3 deletions

```diff
@@ -563,10 +563,12 @@ def pyarrow_types_from_pandas(  # noqa: PLR0912,PLR0915
         for field in fields:
             name = str(field.name)
             # Check if any of the index columns must be ignored
-            if name not in ignore_cols:
+            if name in ignore_cols:
+                cols_dtypes[name] = None
+            else:
                 _logger.debug("Inferring PyArrow type from index: %s", name)
                 cols_dtypes[name] = field.type
-                indexes.append(name)
+            indexes.append(name)
 
     # Merging Index
     sorted_cols: list[str] = indexes + list(df.columns) if index_left is True else list(df.columns) + indexes
@@ -693,13 +695,26 @@ def pyarrow_schema_from_pandas(
         df=df, index=index, ignore_cols=ignore_plus
     )
     for k, v in casts.items():
-        if (k in df.columns) and (k not in ignore):
+        if (k not in ignore) and (k in df.columns or _is_index_name(k, df.index)):
             columns_types[k] = athena2pyarrow(dtype=v)
     columns_types = {k: v for k, v in columns_types.items() if v is not None}
     _logger.debug("columns_types: %s", columns_types)
     return pa.schema(fields=columns_types)
 
 
+def _is_index_name(name: str, index: pd.Index) -> bool:
+    if name in index.names:
+        # named index level
+        return True
+
+    if (match := re.match(r"__index_level_(?P<level>\d+)__", name)) is not None:
+        # unnamed index level
+        if len(index.names) > (level := int(match.group("level"))):
+            return index.names[level] is None
+
+    return False
+
+
 def athena_types_from_pyarrow_schema(
     schema: pa.Schema,
     ignore_null: bool = False,
```
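A note on the `__index_level_(\d+)__` pattern matched by `_is_index_name`: when PyArrow preserves a pandas index, any unnamed level is stored under a synthetic `__index_level_N__` field name. A minimal sketch, not part of this commit:

```python
# Minimal sketch (not part of this commit): PyArrow stores unnamed pandas
# index levels under synthetic "__index_level_N__" field names.
import pandas as pd
import pyarrow as pa

idx = pd.MultiIndex.from_arrays([[0, 1], [2, 3]], names=["c0", None])
df = pd.DataFrame({"c2": [4, 5]}, index=idx)

schema = pa.Schema.from_pandas(df, preserve_index=True)
# The named level keeps its name; the unnamed one gets a synthetic name.
print(schema.names)  # ['c2', 'c0', '__index_level_1__']
```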

awswrangler/_utils.py — 6 additions & 1 deletion

```diff
@@ -31,6 +31,7 @@
 import numpy as np
 import pyarrow as pa
 from botocore.config import Config
+from packaging import version
 
 import awswrangler.pandas as pd
 from awswrangler import _config, exceptions
@@ -893,7 +894,11 @@ def split_pandas_frame(df: pd.DataFrame, splits: int) -> list[pd.DataFrame]:
 @engine.dispatch_on_engine
 def table_refs_to_df(tables: list[pa.Table], kwargs: dict[str, Any]) -> pd.DataFrame:
     """Build Pandas DataFrame from list of PyArrow tables."""
-    return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
+    promote_kwargs: dict[str, bool | str] = {"promote": True}
+    if version.parse(pa.__version__) >= version.parse("14.0.0"):
+        promote_kwargs = {"promote_options": "default"}
+
+    return _table_to_df(pa.concat_tables(tables, **promote_kwargs), kwargs=kwargs)
 
 
 @engine.dispatch_on_engine
```
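The version gate exists because PyArrow 14 deprecated the boolean `promote` flag on `concat_tables` in favor of `promote_options`. A standalone sketch of the same pattern, not awswrangler code:

```python
# Standalone sketch (not awswrangler code): schema promotion when
# concatenating tables, across the PyArrow 14 API change.
import pyarrow as pa
from packaging import version

t1 = pa.table({"a": [1, 2]})
t2 = pa.table({"a": [3], "b": ["x"]})  # extra column

if version.parse(pa.__version__) >= version.parse("14.0.0"):
    combined = pa.concat_tables([t1, t2], promote_options="default")
else:
    combined = pa.concat_tables([t1, t2], promote=True)

# Missing fields are filled with nulls after promotion.
print(combined.schema)  # a: int64, b: string
```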

awswrangler/s3/_write_dataset.py — 4 additions & 2 deletions

```diff
@@ -159,8 +159,10 @@ def _to_partitions(
             inplace=True,
         )
         # Drop index levels if partitioning by index columns
-        subgroup = subgroup.droplevel(  # noqa: PLW2901
-            level=[col for col in partition_cols if col in subgroup.index.names]
+        subgroup.reset_index(
+            level=[col for col in partition_cols if col in subgroup.index.names],
+            drop=True,
+            inplace=True,
         )
         prefix = _delete_objects(
             keys=keys,
```
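The switch from `droplevel` to `reset_index(..., drop=True, inplace=True)` matters when the partition columns cover every index level: pandas refuses to remove all levels via `droplevel`, whereas `reset_index` replaces an emptied index with a fresh `RangeIndex`. A minimal illustration, not awswrangler code:

```python
# Minimal illustration (not awswrangler code) of why reset_index is used.
import pandas as pd

idx = pd.MultiIndex.from_tuples([(0, 2), (1, 3)], names=["c0", "c1"])
df = pd.DataFrame({"c2": [4, 5]}, index=idx)

try:
    df.droplevel(["c0", "c1"])  # cannot drop every level
except ValueError as err:
    print(err)  # "Cannot remove 2 levels from an index with 2 levels: ..."

# reset_index with drop=True discards the levels and leaves a RangeIndex.
df.reset_index(level=["c0", "c1"], drop=True, inplace=True)
print(df.index)  # RangeIndex(start=0, stop=2, step=1)
```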

tests/unit/test_s3_parquet.py — 55 additions & 0 deletions

```diff
@@ -506,6 +506,61 @@ def test_index_columns(path, use_threads, name, pandas):
     assert df[["c0"]].equals(df2)
 
 
+@pytest.mark.parametrize("index", [None, ["c0"], ["c0", "c1"]])
+def test_index_schema_validation(path, glue_database, glue_table, index):
+    df = pd.DataFrame({"c0": [0, 1], "c1": [2, 3], "c2": [4, 5]}, dtype="Int64")
+
+    if index is not None:
+        df = df.set_index(index)
+    else:
+        df.index = df.index.astype("Int64")
+
+    for _ in range(2):
+        wr.s3.to_parquet(df, path, index=True, dataset=True, database=glue_database, table=glue_table)
+
+    df2 = wr.s3.read_parquet(path, validate_schema=True)
+    assert_pandas_equals(pd.concat([df, df]), df2)
+
+
+@pytest.mark.modin_index
+@pytest.mark.parametrize("index", [["c0"], ["c0", "c1"]])
+@pytest.mark.parametrize("partition_cols", [["c0"], ["c0", "c1"]])
+def test_index_partition(path, glue_database, glue_table, index, partition_cols):
+    df = pd.DataFrame({"c0": [0, 1], "c1": [2, 3], "c2": [4, 5]}, dtype="Int64")
+    df = df.set_index(index)
+
+    for _ in range(2):
+        wr.s3.to_parquet(
+            df,
+            path,
+            index=True,
+            dataset=True,
+            partition_cols=partition_cols,
+            database=glue_database,
+            table=glue_table,
+        )
+
+    df2 = wr.s3.read_parquet(path, dataset=True)
+
+    # partitioned index is not preserved, so reset unpartitioned index for recreation
+    assert all(idx in df2.index.names for idx in [idx for idx in index if idx not in partition_cols])
+    df2 = df2.reset_index()
+
+    # partition columns come back as categorical, so convert back
+    for col in partition_cols:
+        df2[col] = df2[col].astype("Int64")
+
+    # apply full index again
+    df2 = df2.set_index(index)
+
+    assert_pandas_equals(
+        # partitioned on index, so the data comes back sorted on the index
+        pd.concat([df, df]).sort_index(),
+        # need to reorder columns, because partition columns are appended
+        df2[df.columns],
+    )
+
+
 @pytest.mark.parametrize("use_threads", [True, False, 2])
 @pytest.mark.parametrize("name", [None, "foo"])
 @pytest.mark.parametrize("pandas", [True, False])
```