Commit fbb0c96

Add max_rows_by_file argument and more. #283
Parent: cc7694a

8 files changed: +190 -98 lines

awswrangler/db.py

Lines changed: 7 additions & 0 deletions

@@ -451,6 +451,7 @@ def copy_to_redshift(  # pylint: disable=too-many-arguments
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
+    max_rows_by_file: Optional[int] = 10_000_000,
 ) -> None:
     """Load Pandas DataFrame as a Table on Amazon Redshift using parquet files on S3 as stage.
 

@@ -525,6 +526,11 @@ def copy_to_redshift(  # pylint: disable=too-many-arguments
     s3_additional_kwargs:
         Forward to s3fs, useful for server side encryption
         https://s3fs.readthedocs.io/en/latest/#serverside-encryption
+    max_rows_by_file : int
+        Max number of rows in each file.
+        Default is 10_000_000; None means don't split the files.
+        (e.g. 33554432, 268435456)
+
     Returns
     -------
     None

@@ -556,6 +562,7 @@ def copy_to_redshift(  # pylint: disable=too-many-arguments
         use_threads=use_threads,
        boto3_session=session,
        s3_additional_kwargs=s3_additional_kwargs,
+       max_rows_by_file=max_rows_by_file,
    )["paths"]
    s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
    copy_files_to_redshift(
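
For orientation, here is a minimal, hedged sketch of how a caller might pass the new argument through wr.db.copy_to_redshift. The Glue connection name, bucket, table, and IAM role below are illustrative assumptions, not values from this commit:

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"id": range(25_000_000)})

# Assumed: a Glue connection named "my-redshift-connection" and a staging bucket already exist.
engine = wr.catalog.get_engine(connection="my-redshift-connection")
wr.db.copy_to_redshift(
    df=df,
    path="s3://my-bucket/stage/",
    con=engine,
    schema="public",
    table="my_table",
    mode="overwrite",
    iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role",
    max_rows_by_file=10_000_000,  # stage Parquet files of at most 10 million rows each
)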

awswrangler/s3/_write_concurrent.py

Lines changed: 2 additions & 2 deletions

@@ -44,12 +44,12 @@ def write(self, func: Callable, boto3_session: boto3.Session, **func_kwargs) ->
             )
             self._futures.append(future)
         else:
-            self._results.append(func(boto3_session=boto3_session, **func_kwargs))
+            self._results += func(boto3_session=boto3_session, **func_kwargs)
 
     def close(self):
         """Close the proxy."""
         if self._exec is not None:
             for future in concurrent.futures.as_completed(self._futures):
-                self._results.append(future.result())
+                self._results += future.result()
             self._exec.shutdown(wait=True)
         return self._results
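
The switch from append() to += follows from the return-type change elsewhere in this commit: the proxied func now returns a list of paths instead of a single path, so results must be extended (flattened) rather than nested. A toy illustration, not awswrangler code:

def fake_write_chunk():
    # Stands in for _write_chunk / _to_parquet, which now return a list of paths.
    return ["s3://bucket/key_0.parquet", "s3://bucket/key_1.parquet"]

nested = []
nested.append(fake_write_chunk())  # old behavior: [['...key_0...', '...key_1...']]

flat = []
flat += fake_write_chunk()         # new behavior: ['...key_0...', '...key_1...']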

awswrangler/s3/_write_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -64,7 +64,7 @@ def _to_dataset(
     # Writing
     partitions_values: Dict[str, List[str]] = {}
     if not partition_cols:
-        paths: List[str] = [func(df=df, path_root=path_root, boto3_session=boto3_session, index=index, **func_kwargs)]
+        paths: List[str] = func(df=df, path_root=path_root, boto3_session=boto3_session, index=index, **func_kwargs)
     else:
         paths, partitions_values = _to_partitions(
             func=func,

awswrangler/s3/_write_parquet.py

Lines changed: 120 additions & 85 deletions

@@ -14,17 +14,95 @@
 
 from awswrangler import _data_types, _utils, catalog, exceptions
 from awswrangler._config import apply_configs
-from awswrangler.s3._delete import delete_objects
-from awswrangler.s3._describe import size_objects
-from awswrangler.s3._list import does_object_exist
 from awswrangler.s3._read_parquet import _read_parquet_metadata
 from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write_concurrent import _WriteProxy
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _to_parquet_file(
+def _get_file_path(file_counter: int, file_path: str) -> str:
+    slash_index: int = file_path.rfind("/")
+    dot_index: int = file_path.find(".", slash_index)
+    file_index: str = "_" + str(file_counter)
+    if dot_index == -1:
+        file_path = file_path + file_index
+    else:
+        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
+    return file_path
+
+
+def _get_fs(
+    boto3_session: Optional[boto3.Session], s3_additional_kwargs: Optional[Dict[str, str]]
+) -> s3fs.S3FileSystem:
+    return _utils.get_fs(
+        s3fs_block_size=33_554_432,  # 32 MB (32 * 2**20)
+        session=boto3_session,
+        s3_additional_kwargs=s3_additional_kwargs,
+    )
+
+
+def _new_writer(
+    file_path: str, fs: s3fs.S3FileSystem, compression: Optional[str], schema: pa.Schema
+) -> pyarrow.parquet.ParquetWriter:
+    return pyarrow.parquet.ParquetWriter(
+        where=file_path,
+        write_statistics=True,
+        use_dictionary=True,
+        filesystem=fs,
+        coerce_timestamps="ms",
+        compression=compression,
+        flavor="spark",
+        schema=schema,
+    )
+
+
+def _write_chunk(
+    file_path: str,
+    boto3_session: Optional[boto3.Session],
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    compression: Optional[str],
+    table: pa.Table,
+    offset: int,
+    chunk_size: int,
+):
+    fs = _get_fs(boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
+    with _new_writer(file_path=file_path, fs=fs, compression=compression, schema=table.schema) as writer:
+        writer.write_table(table.slice(offset, chunk_size))
+    return [file_path]
+
+
+def _to_parquet_chunked(
+    file_path: str,
+    boto3_session: Optional[boto3.Session],
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    compression: Optional[str],
+    table: pa.Table,
+    max_rows_by_file: int,
+    num_of_rows: int,
+    cpus: int,
+) -> List[str]:
+    chunks: int = math.ceil(num_of_rows / max_rows_by_file)
+    use_threads: bool = cpus > 1
+    proxy: _WriteProxy = _WriteProxy(use_threads=use_threads)
+    for chunk in range(chunks):
+        offset: int = chunk * max_rows_by_file
+        write_path: str = _get_file_path(chunk, file_path)
+        proxy.write(
+            func=_write_chunk,
+            file_path=write_path,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            compression=compression,
+            table=table,
+            offset=offset,
+            chunk_size=max_rows_by_file,
+        )
+    return proxy.close()  # blocking
+
+
+def _to_parquet(
     df: pd.DataFrame,
     schema: pa.Schema,
     index: bool,

@@ -36,16 +114,15 @@ def _to_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     path: Optional[str] = None,
     path_root: Optional[str] = None,
-    max_file_size: Optional[int] = 0,
-) -> str:
+    max_rows_by_file: Optional[int] = 0,
+) -> List[str]:
     if path is None and path_root is not None:
         file_path: str = f"{path_root}{uuid.uuid4().hex}{compression_ext}.parquet"
     elif path is not None and path_root is None:
         file_path = path
     else:
         raise RuntimeError("path and path_root received at the same time.")
     _logger.debug("file_path: %s", file_path)
-    write_path = file_path
     table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:

@@ -54,64 +131,23 @@ def _to_parquet_file(
             field = pa.field(name=col_name, type=pyarrow_dtype)
             table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype))
             _logger.debug("Casting column %s (%s) to %s (%s)", col_name, col_index, col_type, pyarrow_dtype)
-    fs: s3fs.S3FileSystem = _utils.get_fs(
-        s3fs_block_size=33_554_432,
-        session=boto3_session,
-        s3_additional_kwargs=s3_additional_kwargs,  # 32 MB (32 * 2**20)
-    )
-
-    file_counter, writer, chunks, chunk_size = 1, None, 1, df.shape[0]
-    if max_file_size is not None and max_file_size > 0:
-        chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)
-        chunks = math.ceil(df.shape[0] / chunk_size)
-
-    for chunk in range(chunks):
-        offset = chunk * chunk_size
-
-        if writer is None:
-            writer = pyarrow.parquet.ParquetWriter(
-                where=write_path,
-                write_statistics=True,
-                use_dictionary=True,
-                filesystem=fs,
-                coerce_timestamps="ms",
-                compression=compression,
-                flavor="spark",
-                schema=table.schema,
-            )
-            # handle the case of overwriting an existing file
-            if does_object_exist(write_path):
-                delete_objects([write_path])
-
-        writer.write_table(table.slice(offset, chunk_size))
-
-        if max_file_size == 0 or max_file_size is None:
-            continue
-
-        file_size = writer.file_handle.buffer.__sizeof__()
-        if does_object_exist(write_path):
-            file_size += size_objects([write_path])[write_path]
-
-        if file_size >= max_file_size:
-            write_path = __get_file_path(file_counter, file_path)
-            file_counter += 1
-            writer.close()
-            writer = None
-
-    if writer is not None:
-        writer.close()
-
-    return file_path
-
-
-def __get_file_path(file_counter, file_path):
-    dot_index = file_path.rfind(".")
-    file_index = "-" + str(file_counter)
-    if dot_index == -1:
-        file_path = file_path + file_index
+    if max_rows_by_file is not None and max_rows_by_file > 0:
+        paths: List[str] = _to_parquet_chunked(
+            file_path=file_path,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            compression=compression,
+            table=table,
+            max_rows_by_file=max_rows_by_file,
+            num_of_rows=df.shape[0],
+            cpus=cpus,
+        )
     else:
-        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
-    return file_path
+        fs = _get_fs(boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
+        with _new_writer(file_path=file_path, fs=fs, compression=compression, schema=table.schema) as writer:
+            writer.write_table(table)
+        paths = [file_path]
+    return paths
 
 
 @apply_configs

@@ -120,6 +156,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     path: str,
     index: bool = False,
     compression: Optional[str] = "snappy",
+    max_rows_by_file: Optional[int] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,

@@ -142,7 +179,6 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
-    max_file_size: Optional[int] = 0,
     catalog_id: Optional[str] = None,
 ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
     """Write Parquet file or dataset on Amazon S3.

@@ -175,6 +211,10 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         True to store the DataFrame index in file, otherwise False to ignore it.
     compression: str, optional
         Compression style (``None``, ``snappy``, ``gzip``).
+    max_rows_by_file : int
+        Max number of rows in each file.
+        Default is None, i.e. don't split the files.
+        (e.g. 33554432, 268435456)
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.

@@ -245,10 +285,6 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
-    max_file_size : int
-        If the file size exceeds the specified size in bytes, another file is created
-        Default is 0 i.e. dont split the files
-        (e.g. 33554432 ,268435456,0)
     catalog_id : str, optional
         The ID of the Data Catalog from which to retrieve Databases.
         If none is provided, the AWS account ID is used by default.

@@ -401,24 +437,22 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     _logger.debug("schema: \n%s", schema)
 
     if dataset is False:
-        paths = [
-            _to_parquet_file(
-                df=df,
-                path=path,
-                schema=schema,
-                index=index,
-                cpus=cpus,
-                compression=compression,
-                compression_ext=compression_ext,
-                boto3_session=session,
-                s3_additional_kwargs=s3_additional_kwargs,
-                dtype=dtype,
-                max_file_size=max_file_size,
-            )
-        ]
+        paths = _to_parquet(
+            df=df,
+            path=path,
+            schema=schema,
+            index=index,
+            cpus=cpus,
+            compression=compression,
+            compression_ext=compression_ext,
+            boto3_session=session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            dtype=dtype,
+            max_rows_by_file=max_rows_by_file,
+        )
     else:
         paths, partitions_values = _to_dataset(
-            func=_to_parquet_file,
+            func=_to_parquet,
             concurrent_partitioning=concurrent_partitioning,
             df=df,
             path_root=path,

@@ -433,6 +467,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
             boto3_session=session,
             s3_additional_kwargs=s3_additional_kwargs,
             schema=schema,
+            max_rows_by_file=max_rows_by_file,
         )
     if (database is not None) and (table is not None):
         columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
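
To make the chunking behavior concrete, here is a hedged sketch of the public call path that these changes enable; the bucket, prefix, and row counts are illustrative assumptions:

import pandas as pd
import awswrangler as wr

df = pd.DataFrame({"id": range(2_500_000)})

# With max_rows_by_file=1_000_000, the Arrow table is sliced into
# math.ceil(2_500_000 / 1_000_000) = 3 chunks, each written to its own object.
# _get_file_path() inserts the chunk index before the first extension, so the keys
# look like <uuid>_0.snappy.parquet, <uuid>_1.snappy.parquet, <uuid>_2.snappy.parquet.
paths = wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/prefix/",
    dataset=True,
    max_rows_by_file=1_000_000,
    use_threads=True,
)["paths"]
print(len(paths))  # expected: 3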

awswrangler/s3/_write_text.py

Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ def _to_text(
     path: Optional[str] = None,
     path_root: Optional[str] = None,
     **pandas_kwargs,
-) -> str:
+) -> List[str]:
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
     if path is None and path_root is not None:

@@ -47,7 +47,7 @@ def _to_text(
             df.to_csv(f, **pandas_kwargs)
         elif file_format == "json":
             df.to_json(f, **pandas_kwargs)
-    return file_path
+    return [file_path]
 
 
 @apply_configs

tests/test_athena_parquet.py

Lines changed: 30 additions & 0 deletions

@@ -55,6 +55,36 @@ def test_parquet_catalog(path, path2, glue_table, glue_table2, glue_database):
     assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table2) is True
 
 
+@pytest.mark.parametrize("use_threads", [True, False])
+@pytest.mark.parametrize("max_rows_by_file", [None, 0, 40, 250, 1000])
+@pytest.mark.parametrize("partition_cols", [None, ["par0"], ["par0", "par1"]])
+def test_file_size(path, glue_table, glue_database, use_threads, max_rows_by_file, partition_cols):
+    df = get_df_list()
+    df = pd.concat([df for _ in range(100)])
+    paths = wr.s3.to_parquet(
+        df=df,
+        path=path,
+        index=False,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        max_rows_by_file=max_rows_by_file,
+        use_threads=use_threads,
+        partition_cols=partition_cols,
+    )["paths"]
+    if max_rows_by_file is not None and max_rows_by_file > 0:
+        assert len(paths) >= math.floor(300 / max_rows_by_file)
+    wr.s3.wait_objects_exist(paths, use_threads=use_threads)
+    df2 = wr.s3.read_parquet(path=path, dataset=True, use_threads=use_threads)
+    ensure_data_types(df2, has_list=True)
+    assert df2.shape == (300, 19)
+    assert df.iint8.sum() == df2.iint8.sum()
+    df2 = wr.athena.read_sql_table(database=glue_database, table=glue_table, use_threads=use_threads)
+    ensure_data_types(df2, has_list=True)
+    assert df2.shape == (300, 19)
+    assert df.iint8.sum() == df2.iint8.sum()
+
+
 def test_parquet_catalog_duplicated(path, glue_table, glue_database):
     df = pd.DataFrame({"A": [1], "a": [1]})
     with pytest.raises(wr.exceptions.InvalidDataFrame):
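
The test's lower bound on the number of written objects restates the chunking arithmetic: the fixture has 3 rows and is concatenated 100 times, giving the 300-row frame asserted above. A standalone sketch of the expected counts (the values mirror the parametrization; nothing else is assumed):

import math

num_rows = 300
for max_rows_by_file in (40, 250, 1000):
    files_per_write_call = math.ceil(num_rows / max_rows_by_file)    # chunks produced by one write call
    asserted_lower_bound = math.floor(num_rows / max_rows_by_file)   # what the test checks
    # With partition_cols set, each partition is written separately, so the total object
    # count can exceed files_per_write_call; hence the test asserts only >= floor(...).
    print(max_rows_by_file, files_per_write_call, asserted_lower_bound)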
