 
 from awswrangler import _data_types, _utils, catalog, exceptions
 from awswrangler._config import apply_configs
-from awswrangler.s3._delete import delete_objects
-from awswrangler.s3._describe import size_objects
-from awswrangler.s3._list import does_object_exist
 from awswrangler.s3._read_parquet import _read_parquet_metadata
 from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
+from awswrangler.s3._write_concurrent import _WriteProxy
 from awswrangler.s3._write_dataset import _to_dataset
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _to_parquet_file(
+def _get_file_path(file_counter: int, file_path: str) -> str:
+    slash_index: int = file_path.rfind("/")
+    dot_index: int = file_path.find(".", slash_index)
+    file_index: str = "_" + str(file_counter)
+    if dot_index == -1:
+        file_path = file_path + file_index
+    else:
+        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
+    return file_path
+
+
+def _get_fs(
+    boto3_session: Optional[boto3.Session], s3_additional_kwargs: Optional[Dict[str, str]]
+) -> s3fs.S3FileSystem:
+    return _utils.get_fs(
+        s3fs_block_size=33_554_432,  # 32 MB (32 * 2**20)
+        session=boto3_session,
+        s3_additional_kwargs=s3_additional_kwargs,
+    )
+
+
+def _new_writer(
+    file_path: str, fs: s3fs.S3FileSystem, compression: Optional[str], schema: pa.Schema
+) -> pyarrow.parquet.ParquetWriter:
+    return pyarrow.parquet.ParquetWriter(
+        where=file_path,
+        write_statistics=True,
+        use_dictionary=True,
+        filesystem=fs,
+        coerce_timestamps="ms",
+        compression=compression,
+        flavor="spark",
+        schema=schema,
+    )
+
+
+def _write_chunk(
+    file_path: str,
+    boto3_session: Optional[boto3.Session],
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    compression: Optional[str],
+    table: pa.Table,
+    offset: int,
+    chunk_size: int,
+):
+    fs = _get_fs(boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
+    with _new_writer(file_path=file_path, fs=fs, compression=compression, schema=table.schema) as writer:
+        writer.write_table(table.slice(offset, chunk_size))
+    return [file_path]
+
+
+def _to_parquet_chunked(
+    file_path: str,
+    boto3_session: Optional[boto3.Session],
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    compression: Optional[str],
+    table: pa.Table,
+    max_rows_by_file: int,
+    num_of_rows: int,
+    cpus: int,
+) -> List[str]:
+    chunks: int = math.ceil(num_of_rows / max_rows_by_file)
+    use_threads: bool = cpus > 1
+    proxy: _WriteProxy = _WriteProxy(use_threads=use_threads)
+    for chunk in range(chunks):
+        offset: int = chunk * max_rows_by_file
+        write_path: str = _get_file_path(chunk, file_path)
+        proxy.write(
+            func=_write_chunk,
+            file_path=write_path,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            compression=compression,
+            table=table,
+            offset=offset,
+            chunk_size=max_rows_by_file,
+        )
+    return proxy.close()  # blocking
+
+
+def _to_parquet(
     df: pd.DataFrame,
     schema: pa.Schema,
     index: bool,
@@ -36,16 +114,15 @@ def _to_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     path: Optional[str] = None,
     path_root: Optional[str] = None,
-    max_file_size: Optional[int] = 0,
-) -> str:
+    max_rows_by_file: Optional[int] = 0,
+) -> List[str]:
     if path is None and path_root is not None:
         file_path: str = f"{path_root}{uuid.uuid4().hex}{compression_ext}.parquet"
     elif path is not None and path_root is None:
         file_path = path
     else:
         raise RuntimeError("path and path_root received at the same time.")
     _logger.debug("file_path: %s", file_path)
-    write_path = file_path
     table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
@@ -54,64 +131,23 @@ def _to_parquet_file(
             field = pa.field(name=col_name, type=pyarrow_dtype)
             table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype))
             _logger.debug("Casting column %s (%s) to %s (%s)", col_name, col_index, col_type, pyarrow_dtype)
-    fs: s3fs.S3FileSystem = _utils.get_fs(
-        s3fs_block_size=33_554_432,
-        session=boto3_session,
-        s3_additional_kwargs=s3_additional_kwargs,  # 32 MB (32 * 2**20)
-    )
-
-    file_counter, writer, chunks, chunk_size = 1, None, 1, df.shape[0]
-    if max_file_size is not None and max_file_size > 0:
-        chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)
-        chunks = math.ceil(df.shape[0] / chunk_size)
-
-    for chunk in range(chunks):
-        offset = chunk * chunk_size
-
-        if writer is None:
-            writer = pyarrow.parquet.ParquetWriter(
-                where=write_path,
-                write_statistics=True,
-                use_dictionary=True,
-                filesystem=fs,
-                coerce_timestamps="ms",
-                compression=compression,
-                flavor="spark",
-                schema=table.schema,
-            )
-            # handle the case of overwriting an existing file
-            if does_object_exist(write_path):
-                delete_objects([write_path])
-
-        writer.write_table(table.slice(offset, chunk_size))
-
-        if max_file_size == 0 or max_file_size is None:
-            continue
-
-        file_size = writer.file_handle.buffer.__sizeof__()
-        if does_object_exist(write_path):
-            file_size += size_objects([write_path])[write_path]
-
-        if file_size >= max_file_size:
-            write_path = __get_file_path(file_counter, file_path)
-            file_counter += 1
-            writer.close()
-            writer = None
-
-    if writer is not None:
-        writer.close()
-
-    return file_path
-
-
-def __get_file_path(file_counter, file_path):
-    dot_index = file_path.rfind(".")
-    file_index = "-" + str(file_counter)
-    if dot_index == -1:
-        file_path = file_path + file_index
+    if max_rows_by_file is not None and max_rows_by_file > 0:
+        paths: List[str] = _to_parquet_chunked(
+            file_path=file_path,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            compression=compression,
+            table=table,
+            max_rows_by_file=max_rows_by_file,
+            num_of_rows=df.shape[0],
+            cpus=cpus,
+        )
     else:
-        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
-    return file_path
+        fs = _get_fs(boto3_session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
+        with _new_writer(file_path=file_path, fs=fs, compression=compression, schema=table.schema) as writer:
+            writer.write_table(table)
+        paths = [file_path]
+    return paths
 
 
 @apply_configs
@@ -120,6 +156,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
     path: str,
     index: bool = False,
     compression: Optional[str] = "snappy",
+    max_rows_by_file: Optional[int] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
@@ -142,7 +179,6 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
-    max_file_size: Optional[int] = 0,
     catalog_id: Optional[str] = None,
 ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
     """Write Parquet file or dataset on Amazon S3.
@@ -175,6 +211,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
         True to store the DataFrame index in file, otherwise False to ignore it.
     compression: str, optional
         Compression style (``None``, ``snappy``, ``gzip``).
+    max_rows_by_file : int
+        Max number of rows in each file.
+        Default is None, i.e. don't split the files.
+        (e.g. 33554432, 268435456)
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -245,10 +285,6 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
-    max_file_size : int
-        If the file size exceeds the specified size in bytes, another file is created
-        Default is 0 i.e. dont split the files
-        (e.g. 33554432, 268435456, 0)
     catalog_id : str, optional
         The ID of the Data Catalog from which to retrieve Databases.
         If none is provided, the AWS account ID is used by default.
@@ -401,24 +437,22 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
     _logger.debug("schema: \n%s", schema)
 
     if dataset is False:
-        paths = [
-            _to_parquet_file(
-                df=df,
-                path=path,
-                schema=schema,
-                index=index,
-                cpus=cpus,
-                compression=compression,
-                compression_ext=compression_ext,
-                boto3_session=session,
-                s3_additional_kwargs=s3_additional_kwargs,
-                dtype=dtype,
-                max_file_size=max_file_size,
-            )
-        ]
+        paths = _to_parquet(
+            df=df,
+            path=path,
+            schema=schema,
+            index=index,
+            cpus=cpus,
+            compression=compression,
+            compression_ext=compression_ext,
+            boto3_session=session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            dtype=dtype,
+            max_rows_by_file=max_rows_by_file,
+        )
     else:
         paths, partitions_values = _to_dataset(
-            func=_to_parquet_file,
+            func=_to_parquet,
             concurrent_partitioning=concurrent_partitioning,
             df=df,
             path_root=path,
@@ -433,6 +467,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
             boto3_session=session,
             s3_additional_kwargs=s3_additional_kwargs,
             schema=schema,
+            max_rows_by_file=max_rows_by_file,
         )
     if (database is not None) and (table is not None):
         columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
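
For reference, a minimal sketch of how the new max_rows_by_file argument would be used through the public wr.s3.to_parquet API. The bucket, key, and row counts are illustrative, not taken from this diff; the output naming follows the _get_file_path helper above, which inserts "_0", "_1", ... before the first dot after the last "/".

import pandas as pd
import awswrangler as wr

# Illustrative DataFrame with 1,000,000 rows.
df = pd.DataFrame({"id": range(1_000_000), "value": "x"})

# With max_rows_by_file left as None (the default), a single Parquet object is written.
# With max_rows_by_file=250_000, the Arrow table is sliced into
# ceil(1_000_000 / 250_000) = 4 chunks, each written through _WriteProxy
# (concurrently when use_threads=True).
result = wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/prefix/my_file.snappy.parquet",  # hypothetical path
    max_rows_by_file=250_000,
    use_threads=True,
)

# result["paths"] would then list four objects, e.g.
# s3://my-bucket/prefix/my_file_0.snappy.parquet ... my_file_3.snappy.parquet
print(result["paths"])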