
Commit 4edc97d

Add schema evolution to s3.to_csv (#799)
1 parent 5813e74 commit 4edc97d

File tree

4 files changed: +46, -22 lines changed
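
At the call site, the new flag guards dataset appends against silent schema drift. A minimal usage sketch (bucket, database, and table names are placeholders; schema_evolution is left at its new default of False):

    import pandas as pd
    import awswrangler as wr

    # First write creates the Glue table with columns c0, c1.
    wr.s3.to_csv(
        df=pd.DataFrame({"c0": [0, 1], "c1": [2, 3]}),
        path="s3://my-bucket/my-prefix/",
        dataset=True,
        database="my_db",
        table="my_table",
    )

    # Appending a frame with an extra column c2 now raises
    # wr.exceptions.InvalidArgumentValue instead of silently diverging
    # from the catalog, because schema_evolution defaults to False.
    wr.s3.to_csv(
        df=pd.DataFrame({"c0": [4], "c1": [5], "c2": [6]}),
        path="s3://my-bucket/my-prefix/",
        dataset=True,
        database="my_db",
        table="my_table",
        mode="append",
    )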

awswrangler/s3/_write.py

Lines changed: 17 additions & 0 deletions
@@ -99,3 +99,20 @@ def _sanitize(
     dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()}
     _utils.check_duplicated_columns(df=df)
     return df, dtype, partition_cols
+
+
+def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
+    if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
+        catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
+        for c, t in columns_types.items():
+            if c not in catalog_cols:
+                raise exceptions.InvalidArgumentValue(
+                    f"Schema change detected: New column {c} with type {t}. "
+                    "Please pass schema_evolution=True to allow "
+                    "new columns."
+                )
+            if t != catalog_cols[c]:  # Data type change detected!
+                raise exceptions.InvalidArgumentValue(
+                    f"Schema change detected: Data type change on column {c} "
+                    f"(Old type: {catalog_cols[c]} / New type {t})."
+                )
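
The helper is self-contained and can be exercised in isolation. A minimal sketch, with table_input hand-built to mimic the relevant slice of a Glue table definition (column names and types are made up):

    from awswrangler import exceptions
    from awswrangler.s3._write import _check_schema_changes

    # Stand-in for the catalog table input fetched from Glue.
    table_input = {
        "StorageDescriptor": {
            "Columns": [
                {"Name": "c0", "Type": "bigint"},
                {"Name": "c1", "Type": "bigint"},
            ]
        }
    }

    # Matching schema: passes silently.
    _check_schema_changes(
        columns_types={"c0": "bigint", "c1": "bigint"}, table_input=table_input, mode="append"
    )

    # New column c2: raises InvalidArgumentValue. A type change on an
    # existing column (e.g. c1 -> "string") raises the same exception.
    try:
        _check_schema_changes(
            columns_types={"c0": "bigint", "c2": "string"}, table_input=table_input, mode="append"
        )
    except exceptions.InvalidArgumentValue as err:
        print(err)

    # Modes other than append/overwrite_partitions bypass the check,
    # as does table_input=None (no pre-existing table).
    _check_schema_changes(columns_types={"c0": "string"}, table_input=table_input, mode="overwrite")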

awswrangler/s3/_write_parquet.py

Lines changed: 1 addition & 18 deletions
@@ -17,30 +17,13 @@
 from awswrangler.s3._delete import delete_objects
 from awswrangler.s3._fs import open_s3_object
 from awswrangler.s3._read_parquet import _read_parquet_metadata
-from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
 from awswrangler.s3._write_concurrent import _WriteProxy
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Dict[str, Any]], mode: str) -> None:
-    if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
-        catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
-        for c, t in columns_types.items():
-            if c not in catalog_cols:
-                raise exceptions.InvalidArgumentValue(
-                    f"Schema change detected: New column {c} with type {t}. "
-                    "Please pass schema_evolution=True to allow "
-                    "new columns."
-                )
-            if t != catalog_cols[c]:  # Data type change detected!
-                raise exceptions.InvalidArgumentValue(
-                    f"Schema change detected: Data type change on column {c} "
-                    f"(Old type: {catalog_cols[c]} / New type {t})."
-                )
-
-
 def _get_file_path(file_counter: int, file_path: str) -> str:
     slash_index: int = file_path.rfind("/")
     dot_index: int = file_path.find(".", slash_index)

awswrangler/s3/_write_text.py

Lines changed: 17 additions & 4 deletions
@@ -14,7 +14,7 @@
 from awswrangler._config import apply_configs
 from awswrangler.s3._delete import delete_objects
 from awswrangler.s3._fs import open_s3_object
-from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _check_schema_changes, _sanitize, _validate_args
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)

@@ -87,6 +87,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
     concurrent_partitioning: bool = False,
     mode: Optional[str] = None,
     catalog_versioning: bool = False,
+    schema_evolution: bool = False,
     database: Optional[str] = None,
     table: Optional[str] = None,
     dtype: Optional[Dict[str, str]] = None,

@@ -182,6 +183,11 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
         https://aws-data-wrangler.readthedocs.io/en/2.9.0/stubs/awswrangler.s3.to_parquet.html#awswrangler.s3.to_parquet
     catalog_versioning : bool
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
+    schema_evolution : bool
+        If True, allows schema evolution (new or missing columns), otherwise an exception will be raised.
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        Related tutorial:
+        https://aws-data-wrangler.readthedocs.io/en/2.9.0/tutorials/014%20-%20Schema%20Evolution.html
     database : str, optional
         Glue/Athena catalog: Database name.
     table : str, optional

@@ -474,6 +480,16 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
     pd_kwargs.pop("compression", None)
 
     df = df[columns] if columns else df
+
+    columns_types: Dict[str, str] = {}
+    partitions_types: Dict[str, str] = {}
+    if (database is not None) and (table is not None):
+        columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
+            df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
+        )
+        if schema_evolution is False:
+            _check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
+
     paths, partitions_values = _to_dataset(
         func=_to_text,
         concurrent_partitioning=concurrent_partitioning,

@@ -498,9 +514,6 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
     )
     if database and table:
         try:
-            columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
-                df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
-            )
             serde_info: Dict[str, Any] = {}
             if catalog_table_input:
                 serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
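
Note the design choice: the type inference moves ahead of _to_dataset, so schema drift is detected before any CSV objects are written to S3 (previously, per the final hunk above, the types were only computed at catalog-update time). A rough sketch of what the internal inference call yields (printed values are illustrative, not verified output):

    import pandas as pd
    from awswrangler import _data_types

    df = pd.DataFrame({"c0": [0, 1, 2], "par": ["a", "b", "a"]})
    # Partition columns are split out of the regular column mapping.
    columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
        df=df, index=False, partition_cols=["par"], dtype=None, index_left=True
    )
    print(columns_types)     # expected: {'c0': 'bigint'}
    print(partitions_types)  # expected: {'par': 'string'}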

tests/test_s3_text.py

Lines changed: 11 additions & 0 deletions
@@ -331,3 +331,14 @@ def test_read_csv_versioned(path) -> None:
     df_temp = wr.s3.read_csv(path_file, version_id=version_id)
     assert df_temp.equals(df)
     assert version_id == wr.s3.describe_objects(path=path_file, version_id=version_id)[path_file]["VersionId"]
+
+
+def test_to_csv_schema_evolution(path, glue_database, glue_table) -> None:
+    path_file = f"{path}0.csv"
+    df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
+    wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)
+    df["test"] = 1
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_csv(
+            df=df, path=path_file, dataset=True, database=glue_database, table=glue_table, schema_evolution=True
+        )
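
The committed test covers the new-column branch of the guard. A hypothetical companion case (not part of this commit) could exercise the data-type-change branch, relying on the default schema_evolution=False:

    def test_to_csv_schema_evolution_dtype_change(path, glue_database, glue_table) -> None:
        # Hypothetical sketch: same fixtures as above, triggers the
        # dtype-change branch of _check_schema_changes.
        path_file = f"{path}0.csv"
        df = pd.DataFrame({"c0": [0, 1, 2]})
        wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)
        df["c0"] = df["c0"].astype(str)  # bigint -> string in Athena terms
        with pytest.raises(wr.exceptions.InvalidArgumentValue):
            wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)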
