
Commit cc7694a

Add an argument max_file_size in bytes which ensures that the files written do not cross the file size specified (#341)

* Add an argument max_file_size in bytes which ensures that the files written do not cross the specified file size. Refer to #283 for further details.
* Address formatting and type failures.

Co-authored-by: Igor Tavares <[email protected]>
1 parent 282dac7 commit cc7694a
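
In practice the new argument is passed straight to wr.s3.to_parquet. A minimal usage sketch based on this commit (not part of the commit itself; the bucket name, the DataFrame, and the 32 MiB cap below are placeholders):

import awswrangler as wr
import pandas as pd

# Placeholder data; any sufficiently large DataFrame triggers the split.
df = pd.DataFrame({"id": range(1_000_000), "value": ["x" * 32] * 1_000_000})

# Cap each written object at roughly 32 MiB. When the cap is hit, additional
# objects are written as test-1.parquet, test-2.parquet, and so on.
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/test.parquet",
    index=False,
    dataset=False,
    max_file_size=32 * 2 ** 20,
)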

2 files changed: +80 −11 lines changed

awswrangler/s3/_write_parquet.py

Lines changed: 64 additions & 11 deletions
@@ -1,6 +1,7 @@
 """Amazon PARQUET S3 Parquet Write Module (PRIVATE)."""

 import logging
+import math
 import uuid
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -13,6 +14,9 @@

 from awswrangler import _data_types, _utils, catalog, exceptions
 from awswrangler._config import apply_configs
+from awswrangler.s3._delete import delete_objects
+from awswrangler.s3._describe import size_objects
+from awswrangler.s3._list import does_object_exist
 from awswrangler.s3._read_parquet import _read_parquet_metadata
 from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
 from awswrangler.s3._write_dataset import _to_dataset
@@ -32,6 +36,7 @@ def _to_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     path: Optional[str] = None,
     path_root: Optional[str] = None,
+    max_file_size: Optional[int] = 0,
 ) -> str:
     if path is None and path_root is not None:
         file_path: str = f"{path_root}{uuid.uuid4().hex}{compression_ext}.parquet"
@@ -40,6 +45,7 @@ def _to_parquet_file(
     else:
         raise RuntimeError("path and path_root received at the same time.")
     _logger.debug("file_path: %s", file_path)
+    write_path = file_path
     table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
@@ -53,17 +59,58 @@ def _to_parquet_file(
         session=boto3_session,
         s3_additional_kwargs=s3_additional_kwargs, # 32 MB (32 * 2**20)
     )
-    with pyarrow.parquet.ParquetWriter(
-        where=file_path,
-        write_statistics=True,
-        use_dictionary=True,
-        filesystem=fs,
-        coerce_timestamps="ms",
-        compression=compression,
-        flavor="spark",
-        schema=table.schema,
-    ) as writer:
-        writer.write_table(table)
+
+    file_counter, writer, chunks, chunk_size = 1, None, 1, df.shape[0]
+    if max_file_size is not None and max_file_size > 0:
+        chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)
+        chunks = math.ceil(df.shape[0] / chunk_size)
+
+    for chunk in range(chunks):
+        offset = chunk * chunk_size
+
+        if writer is None:
+            writer = pyarrow.parquet.ParquetWriter(
+                where=write_path,
+                write_statistics=True,
+                use_dictionary=True,
+                filesystem=fs,
+                coerce_timestamps="ms",
+                compression=compression,
+                flavor="spark",
+                schema=table.schema,
+            )
+            # handle the case of overwriting an existing file
+            if does_object_exist(write_path):
+                delete_objects([write_path])
+
+        writer.write_table(table.slice(offset, chunk_size))
+
+        if max_file_size == 0 or max_file_size is None:
+            continue
+
+        file_size = writer.file_handle.buffer.__sizeof__()
+        if does_object_exist(write_path):
+            file_size += size_objects([write_path])[write_path]
+
+        if file_size >= max_file_size:
+            write_path = __get_file_path(file_counter, file_path)
+            file_counter += 1
+            writer.close()
+            writer = None
+
+    if writer is not None:
+        writer.close()
+
+    return file_path
+
+
+def __get_file_path(file_counter, file_path):
+    dot_index = file_path.rfind(".")
+    file_index = "-" + str(file_counter)
+    if dot_index == -1:
+        file_path = file_path + file_index
+    else:
+        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
     return file_path

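The hunk above estimates how many rows fit under the cap by assuming bytes are spread evenly across rows, writes the table in row slices, and rotates to a new object name once the accumulated size reaches max_file_size. A standalone sketch of that arithmetic and of the resulting names, under the simplifying assumption of one output file per slice and with made-up numbers (not part of the commit):

import math

num_rows = 1_000           # df.shape[0]
table_nbytes = 10_000_000  # table.nbytes, the in-memory Arrow size used as a proxy
max_file_size = 2_000_000  # requested cap in bytes

# Rows expected to fit in one file, then the number of row slices to write.
chunk_size = int((max_file_size * num_rows) / table_nbytes)  # 200
chunks = math.ceil(num_rows / chunk_size)                    # 5

# Naming mirrors __get_file_path: the first object keeps the original key,
# later ones get a "-<counter>" suffix inserted before the extension.
file_path = "s3://bucket/test.parquet"
dot = file_path.rfind(".")
names = [file_path] + [file_path[:dot] + f"-{i}" + file_path[dot:] for i in range(1, chunks)]
print(chunk_size, chunks)  # 200 5
print(names)
# ['s3://bucket/test.parquet', 's3://bucket/test-1.parquet', 's3://bucket/test-2.parquet',
#  's3://bucket/test-3.parquet', 's3://bucket/test-4.parquet']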

@@ -95,6 +142,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
+    max_file_size: Optional[int] = 0,
     catalog_id: Optional[str] = None,
 ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
     """Write Parquet file or dataset on Amazon S3.
@@ -197,6 +245,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
+    max_file_size : int
+        If the file size exceeds the specified size in bytes, another file is created.
+        Default is 0, i.e. do not split the files.
+        (e.g. 33554432, 268435456, 0)
     catalog_id : str, optional
         The ID of the Data Catalog from which to retrieve Databases.
         If none is provided, the AWS account ID is used by default.
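For reference, the example values in the new docstring entry are plain powers-of-two byte counts; the lines below are only arithmetic, not library code:

print(32 * 2 ** 20)   # 33554432  -> a 32 MiB cap
print(256 * 2 ** 20)  # 268435456 -> a 256 MiB cap
print(0)              # 0 keeps the default behaviour: no splitting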
@@ -361,6 +413,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
                 boto3_session=session,
                 s3_additional_kwargs=s3_additional_kwargs,
                 dtype=dtype,
+                max_file_size=max_file_size,
             )
         ]
     else:

tests/test_moto.py

Lines changed: 16 additions & 0 deletions
@@ -309,6 +309,22 @@ def test_parquet(moto_s3):
     assert df.shape == (3, 19)


+def test_parquet_with_size(moto_s3):
+    path = "s3://bucket/test.parquet"
+    df = get_df_list()
+    for i in range(20):
+        df = pd.concat([df, get_df_list()])
+    wr.s3.to_parquet(df=df, path=path, index=False, dataset=False, max_file_size=1 * 2 ** 10)
+    df = wr.s3.read_parquet(path="s3://bucket/", dataset=False)
+    ensure_data_types(df, has_list=True)
+    assert df.shape == (63, 19)
+    file_objects = wr.s3.list_objects(path="s3://bucket/")
+    assert len(file_objects) == 9
+    for i in range(7):
+        assert f"s3://bucket/test-{i+1}.parquet" in file_objects
+    assert "s3://bucket/test.parquet" in file_objects
+
+
 def test_s3_delete_object_success(moto_s3):
     path = "s3://bucket/test.parquet"
     wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"])
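The new test relies on the repository's moto_s3 fixture to mock S3 locally. A minimal sketch of how such a moto-backed fixture can look (an assumption for illustration, not the repository's actual fixture code):

import boto3
import moto
import pytest


@pytest.fixture(scope="module")
def moto_s3():
    # Mock S3 in-process so wr.s3.* calls never touch real AWS.
    with moto.mock_s3():
        s3 = boto3.resource("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="bucket")
        yield s3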
