
Commit cb7fb67

Updated implementation of “overwrite_files”
1 parent e23e1a8 commit cb7fb67

3 files changed: 25 additions & 15 deletions

awswrangler/s3/_write.py

Lines changed: 23 additions & 15 deletions
@@ -6,7 +6,7 @@
 import uuid
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, NamedTuple
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Tuple
 
 import boto3
 import pandas as pd
@@ -35,6 +35,21 @@
 }
 
 
+def _compose_filename_prefix_for_mode(*, mode: str, filename_prefix: str | None = None) -> Tuple[str, str]:
+    if mode == "overwrite_files":
+        if filename_prefix is None:
+            filename_prefix = "part"
+        random_filename_suffix = ""
+        mode = "append"
+    else:
+        random_filename_suffix = uuid.uuid4().hex
+
+    if filename_prefix is None:
+        filename_prefix = ""
+    filename_prefix = filename_prefix + random_filename_suffix
+    return filename_prefix, mode
+
+
 def _extract_dtypes_from_table_input(table_input: dict[str, Any]) -> dict[str, str]:
     dtypes: dict[str, str] = {}
     for col in table_input["StorageDescriptor"]["Columns"]:
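
The extracted helper encodes the naming policy in one place: "overwrite_files" pins the prefix to a deterministic value ("part" unless one is given) and downgrades the mode to a plain "append", while every other mode appends a uuid4 hex suffix so repeated writes never collide. A minimal standalone sketch of that behavior (helper body copied from the hunk above; the printed values are illustrative):

    import uuid

    def _compose_filename_prefix_for_mode(*, mode: str, filename_prefix: str | None = None) -> tuple[str, str]:
        # Deterministic names for "overwrite_files"; random suffix otherwise.
        if mode == "overwrite_files":
            if filename_prefix is None:
                filename_prefix = "part"
            random_filename_suffix = ""
            mode = "append"
        else:
            random_filename_suffix = uuid.uuid4().hex
        if filename_prefix is None:
            filename_prefix = ""
        return filename_prefix + random_filename_suffix, mode

    print(_compose_filename_prefix_for_mode(mode="overwrite_files"))
    # ('part', 'append')  -- stable prefix: a re-run targets the same S3 keys
    print(_compose_filename_prefix_for_mode(mode="append", filename_prefix="data_"))
    # e.g. ('data_6f9b0c...e2', 'append')  -- 32-char hex suffix keeps runs distinct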
@@ -71,6 +86,7 @@ def _validate_args(
     parameters: dict[str, str] | None,
     columns_comments: dict[str, str] | None,
     columns_parameters: dict[str, dict[str, str]] | None,
+    max_rows_by_file: int | None,
     execution_engine: Enum,
 ) -> None:
     if df.empty is True:
@@ -88,6 +104,10 @@
             raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.")
         if mode is not None:
             raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
+        if mode == "overwrite_files" and (max_rows_by_file or bucketing_info):
+            raise exceptions.InvalidArgumentValue(
+                "When mode is set to 'overwrite_files', the "
+                "`max_rows_by_file` and `bucketing_info` arguments cannot be set.")
         if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
             raise exceptions.InvalidArgumentCombination(
                 "Please pass dataset=True to be able to use any one of these "
@@ -278,20 +298,8 @@ def write(  # noqa: PLR0913
         dtype = dtype if dtype else {}
         partitions_values: dict[str, list[str]] = {}
 
-        if mode == "overwrite_files":
-            assert max_rows_by_file in [None, 0]
-
-            if filename_prefix is None:
-                filename_prefix = "part"
-            random_filename_suffix = ""
-            mode = "append"
-        else:
-            random_filename_suffix = uuid.uuid4().hex
-
-        if filename_prefix is None:
-            filename_prefix = ""
-        filename_prefix = filename_prefix + random_filename_suffix
-
+        filename_prefix, mode = _compose_filename_prefix_for_mode(
+            mode=mode, filename_prefix=filename_prefix)
         mode = "append" if mode is None else mode
         cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
         s3_client = _utils.client(service_name="s3", session=boto3_session)

awswrangler/s3/_write_orc.py

Lines changed: 1 addition & 0 deletions
@@ -645,6 +645,7 @@ def to_orc(
         parameters=parameters,
         columns_comments=columns_comments,
         columns_parameters=columns_parameters,
+        max_rows_by_file=max_rows_by_file,
         execution_engine=engine.get(),
     )
 

awswrangler/s3/_write_parquet.py

Lines changed: 1 addition & 0 deletions
@@ -703,6 +703,7 @@ def to_parquet(
         parameters=parameters,
         columns_comments=columns_comments,
         columns_parameters=columns_parameters,
+        max_rows_by_file=max_rows_by_file,
         execution_engine=engine.get(),
     )
 
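
Taken together, the intended happy path looks roughly like this; a sketch under the same assumptions as above (mode="overwrite_files" is this commit's extension, not upstream awswrangler):

    import awswrangler as wr
    import pandas as pd

    df = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})

    # Files are written under a stable "part" prefix with no uuid suffix,
    # so re-running this call replaces the previous objects in place
    # instead of accumulating uuid-named files next to them.
    wr.s3.to_parquet(
        df=df,
        path="s3://my-bucket/my-prefix/",  # placeholder
        dataset=True,
        mode="overwrite_files",
    )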
