
Commit 2865c85

(feat) add distributed s3 write parquet (#1526)
* Add distributed s3 write parquet
* Add type mappings to avoid inference
* Refactoring - separate distributed write_parquet implementation
* Replace group iteration with apply() optimized for distributed scenario
* Fix test regressions
* Linting/formatting/isort
* Add repartitioning & allow writing into a single key
* Minor - fix bucketing keys
* Minor - Increase S3 select test timeout
* Minor - read_parquet fix - replace pandas with modin in distributed mode
1 parent 938e83c commit 2865c85

10 files changed (+559 / -107 lines)


awswrangler/_data_types.py

Lines changed: 7 additions & 1 deletion
@@ -454,7 +454,7 @@ def pyarrow2pandas_extension(  # pylint: disable=too-many-branches,too-many-retu
     return None


-def pyarrow_types_from_pandas(  # pylint: disable=too-many-branches
+def pyarrow_types_from_pandas(  # pylint: disable=too-many-branches,too-many-statements
     df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
 ) -> Dict[str, pa.DataType]:
     """Extract the related Pyarrow data types from any Pandas DataFrame."""
@@ -474,8 +474,14 @@ def pyarrow_types_from_pandas(  # pylint: disable=too-many-branches
             cols_dtypes[name] = pa.int32()
         elif dtype == "Int64":
             cols_dtypes[name] = pa.int64()
+        elif dtype == "float32":
+            cols_dtypes[name] = pa.float32()
+        elif dtype == "float64":
+            cols_dtypes[name] = pa.float64()
         elif dtype == "string":
             cols_dtypes[name] = pa.string()
+        elif dtype == "boolean":
+            cols_dtypes[name] = pa.bool_()
         else:
             cols.append(name)
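
As an aside, a minimal sketch of what the patched helper now produces for the newly mapped dtypes; the sample DataFrame and column names are made up, and the call targets an internal function purely for illustration.

# Illustration only: the explicit float32/float64/boolean branches above let
# these dtypes resolve from the mapping instead of falling through to inference.
import pandas as pd

from awswrangler._data_types import pyarrow_types_from_pandas

df = pd.DataFrame(
    {
        "ratio": pd.Series([0.1, 0.2], dtype="float32"),
        "total": pd.Series([1.0, 2.0], dtype="float64"),
        "flag": pd.Series([True, None], dtype="boolean"),
    }
)
print(pyarrow_types_from_pandas(df, index=False))
# Expected mapping: ratio -> float32, total -> float64, flag -> bool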

awswrangler/_databases.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def _records2df(
                 if dtype[col_name] == pa.string() or isinstance(dtype[col_name], pa.Decimal128Type):
                     col_values = oracle.handle_oracle_objects(col_values, col_name, dtype)
                 array = pa.array(obj=col_values, type=dtype[col_name], safe=safe)  # Creating Arrow array with dtype
-            except pa.ArrowInvalid:
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
                 array = pa.array(obj=col_values, safe=safe)  # Creating Arrow array
                 array = array.cast(target_type=dtype[col_name], safe=safe)  # Casting
         arrays.append(array)
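
For reference, a self-contained sketch of the widened fallback in _records2df: build the Arrow array with the target type first, and on ArrowInvalid or ArrowTypeError fall back to type inference followed by an explicit cast. The helper name is hypothetical.

import pyarrow as pa


def _array_with_cast_fallback(values, target_type, safe=True):
    """Mirror of the try/except pattern above (illustrative helper)."""
    try:
        # Preferred path: build the array directly with the expected type.
        return pa.array(obj=values, type=target_type, safe=safe)
    except (pa.ArrowInvalid, pa.ArrowTypeError):
        # Fallback: let Arrow infer the type, then cast to the expected one.
        array = pa.array(obj=values, safe=safe)
        return array.cast(target_type=target_type, safe=safe)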
awswrangler/distributed/datasources/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,11 @@
 """Distributed Datasources Module."""
 
-from awswrangler.distributed.datasources.parquet_datasource import ParquetDatasource
+from awswrangler.distributed.datasources.parquet_datasource import (
+    ParquetDatasource,
+    UserProvidedKeyBlockWritePathProvider,
+)
 
 __all__ = [
     "ParquetDatasource",
+    "UserProvidedKeyBlockWritePathProvider",
 ]

awswrangler/distributed/datasources/parquet_datasource.py

Lines changed: 141 additions & 4 deletions
@@ -1,24 +1,32 @@
 """Distributed ParquetDatasource Module."""
 
 import logging
-from typing import Any, Callable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import pyarrow as pa
 
 # fs required to implicitly trigger S3 subsystem initialization
 import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
 import pyarrow.parquet as pq
-from ray import cloudpickle
+from ray import cloudpickle  # pylint: disable=wrong-import-order,ungrouped-imports
+from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.context import DatasetContext
-from ray.data.datasource.datasource import ReadTask
-from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
+from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider
+from ray.data.datasource.datasource import ReadTask, WriteResult
+from ray.data.datasource.file_based_datasource import (
+    _resolve_paths_and_filesystem,
+    _S3FileSystemWrapper,
+    _wrap_s3_serialization_workaround,
+)
 from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider
 from ray.data.datasource.parquet_datasource import (
     _deregister_parquet_file_fragment_serialization,
     _register_parquet_file_fragment_serialization,
 )
 from ray.data.impl.output_buffer import BlockOutputBuffer
+from ray.data.impl.remote_fn import cached_remote_fn
+from ray.types import ObjectRef
 
 from awswrangler._arrow import _add_table_partitions
 
@@ -29,9 +37,31 @@
 PARQUET_READER_ROW_BATCH_SIZE = 100000
 
 
+class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider):
+    """Block write path provider.
+
+    Used when writing single-block datasets into a user-provided S3 key.
+    """
+
+    def _get_write_path_for_block(
+        self,
+        base_path: str,
+        *,
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        dataset_uuid: Optional[str] = None,
+        block: Optional[ObjectRef[Block[Any]]] = None,
+        block_index: Optional[int] = None,
+        file_format: Optional[str] = None,
+    ) -> str:
+        return base_path
+
+
 class ParquetDatasource:
     """Parquet datasource, for reading and writing Parquet files."""
 
+    def __init__(self) -> None:
+        self._write_paths: List[str] = []
+
     # Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py
     def prepare_read(
         self,
@@ -135,3 +165,110 @@ def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
             _deregister_parquet_file_fragment_serialization()  # type: ignore
 
         return read_tasks
+
+    # Original implementation:
+    # https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/file_based_datasource.py
+    def do_write(
+        self,
+        blocks: List[ObjectRef[Block[Any]]],
+        _: List[BlockMetadata],
+        path: str,
+        dataset_uuid: str,
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        try_create_dir: bool = True,
+        open_stream_args: Optional[Dict[str, Any]] = None,
+        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
+        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
+        _block_udf: Optional[Callable[[Block[Any]], Block[Any]]] = None,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        **write_args: Any,
+    ) -> List[ObjectRef[WriteResult]]:
+        """Create write tasks for a parquet file datasource."""
+        paths, filesystem = _resolve_paths_and_filesystem(path, filesystem)
+        path = paths[0]
+        if try_create_dir:
+            filesystem.create_dir(path, recursive=True)
+        filesystem = _wrap_s3_serialization_workaround(filesystem)
+
+        _write_block_to_file = self._write_block
+
+        if open_stream_args is None:
+            open_stream_args = {}
+
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        def write_block(write_path: str, block: Block[Any]) -> str:
+            _logger.debug("Writing %s file.", write_path)
+            fs: Optional["pyarrow.fs.FileSystem"] = filesystem
+            if isinstance(fs, _S3FileSystemWrapper):
+                fs = fs.unwrap()  # type: ignore
+            if _block_udf is not None:
+                block = _block_udf(block)
+
+            with fs.open_output_stream(write_path, **open_stream_args) as f:
+                _write_block_to_file(
+                    f,
+                    BlockAccessor.for_block(block),
+                    writer_args_fn=write_args_fn,
+                    **write_args,
+                )
+            # This is a change from original FileBasedDatasource.do_write that does not return paths
+            return write_path
+
+        write_block = cached_remote_fn(write_block).options(**ray_remote_args)
+
+        file_format = self._file_format()
+        write_tasks = []
+        for block_idx, block in enumerate(blocks):
+            write_path = block_path_provider(
+                path,
+                filesystem=filesystem,
+                dataset_uuid=dataset_uuid,
+                block=block,
+                block_index=block_idx,
+                file_format=file_format,
+            )
+            write_task = write_block.remote(write_path, block)  # type: ignore
+            write_tasks.append(write_task)
+
+        return write_tasks
+
+    def on_write_complete(self, write_results: List[Any], **_: Any) -> None:
+        """Execute callback on write complete."""
+        _logger.debug("Write complete %s.", write_results)
+        # Collect and return all write task paths
+        self._write_paths.extend(write_results)
+
+    def on_write_failed(self, write_results: List[ObjectRef[Any]], error: Exception, **_: Any) -> None:
+        """Execute callback on write failed."""
+        _logger.debug("Write failed %s.", write_results)
+        raise error
+
+    def get_write_paths(self) -> List[str]:
+        """Return S3 paths of where the results have been written."""
+        return self._write_paths
+
+    def _write_block(
+        self,
+        f: "pyarrow.NativeFile",
+        block: BlockAccessor[Any],
+        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
+        **writer_args: Any,
+    ) -> None:
+        """Write a block to S3."""
+        import pyarrow.parquet as pq  # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported
+
+        writer_args = _resolve_kwargs(writer_args_fn, **writer_args)
+        pq.write_table(block.to_arrow(), f, **writer_args)
+
+    def _file_format(self) -> str:
+        """Return file format."""
+        return "parquet"
+
+
+def _resolve_kwargs(kwargs_fn: Callable[[], Dict[str, Any]], **kwargs: Any) -> Dict[str, Any]:
+    if kwargs_fn:
+        kwarg_overrides = kwargs_fn()
+        kwargs.update(kwarg_overrides)
+    return kwargs

awswrangler/s3/_read_parquet.py

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,6 @@
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 import boto3
-import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset
 import pyarrow.parquet
@@ -30,10 +29,13 @@
 )
 
 if config.distributed:
+    import modin.pandas as pd
     from ray.data import read_datasource
 
     from awswrangler.distributed._utils import _to_modin  # pylint: disable=ungrouped-imports
     from awswrangler.distributed.datasources import ParquetDatasource
+else:
+    import pandas as pd
 
 BATCH_READ_BLOCK_SIZE = 65_536
 CHUNKED_READ_S3_BLOCK_SIZE = 10_485_760  # 10 MB (20 * 2**20)
@@ -323,7 +325,7 @@ def _read_parquet(
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     if config.distributed:
         dataset = read_datasource(
-            datasource=ParquetDatasource(),
+            datasource=ParquetDatasource(),  # type: ignore
             parallelism=parallelism,
             use_threads=use_threads,
             paths=paths,
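
A short, hypothetical usage note on the change above: with the distributed configuration enabled, pd is modin.pandas, so the frame returned by s3.read_parquet is a Modin DataFrame backed by the Ray dataset built from ParquetDatasource; the path is a placeholder.

import awswrangler as wr

# Placeholder prefix; assumes the Ray/Modin distributed config is active.
df = wr.s3.read_parquet(path="s3://my-bucket/dataset/")
print(type(df))  # modin.pandas.DataFrame when distributed, pandas.DataFrame otherwise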

awswrangler/s3/_write.py

Lines changed: 2 additions & 1 deletion
@@ -55,13 +55,14 @@ def _validate_args(
     description: Optional[str],
     parameters: Optional[Dict[str, str]],
     columns_comments: Optional[Dict[str, str]],
+    distributed: Optional[bool] = False,
 ) -> None:
     if df.empty is True:
         raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
     if dataset is False:
         if path is None:
             raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
-        if path.endswith("/"):
+        if not distributed and path.endswith("/"):
             raise exceptions.InvalidArgumentValue(
                 "If <dataset=False>, the argument <path> should be a key, not a prefix."
             )
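
A hypothetical illustration of what the relaxed check permits: with the distributed flag passed through to _validate_args, a dataset=False write may target either an exact key (expected to be repartitioned into a single block and routed through UserProvidedKeyBlockWritePathProvider) or a trailing-slash prefix that receives one file per block; the non-distributed path still rejects the prefix form. Paths are placeholders.

import awswrangler as wr
import modin.pandas as pd

df = pd.DataFrame({"c0": [1, 2, 3]})

# Exact key: the whole frame is written to this one object.
wr.s3.to_parquet(df=df, path="s3://my-bucket/output/data.parquet")

# Prefix: one Parquet file per block; only accepted with dataset=False when
# the distributed flag reaches _validate_args as True.
wr.s3.to_parquet(df=df, path="s3://my-bucket/output/")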
