Commit 7ae4baa

(refactor) ray datasources (#1687)

Refactor ray datasources:
- Common base class for file-based datasources

1 parent a91ded1 · commit 7ae4baa

File tree

5 files changed: +201 −249 lines
Lines changed: 3 additions & 5 deletions

@@ -1,21 +1,19 @@
 """Ray Datasources Module."""

+from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import UserProvidedKeyBlockWritePathProvider
 from awswrangler.distributed.ray.datasources.pandas_text_datasource import (
     PandasCSVDataSource,
     PandasFWFDataSource,
     PandasJSONDatasource,
     PandasTextDatasource,
 )
-from awswrangler.distributed.ray.datasources.parquet_datasource import (
-    ParquetDatasource,
-    UserProvidedKeyBlockWritePathProvider,
-)
+from awswrangler.distributed.ray.datasources.parquet_datasource import ParquetDatasource

 __all__ = [
     "PandasCSVDataSource",
     "PandasFWFDataSource",
     "PandasJSONDatasource",
-    "PandasTextDatasource",
     "ParquetDatasource",
+    "PandasTextDatasource",
     "UserProvidedKeyBlockWritePathProvider",
 ]
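
This hunk moves UserProvidedKeyBlockWritePathProvider out of parquet_datasource and into the new pandas_file_based_datasource module, while keeping it re-exported from the package root. A quick sketch of the expected post-refactor behavior (an assumption based only on the imports shown above, not code from this PR):

    # Both import paths should resolve to the same class after this commit.
    from awswrangler.distributed.ray.datasources import UserProvidedKeyBlockWritePathProvider
    from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import (
        UserProvidedKeyBlockWritePathProvider as ProviderFromNewModule,
    )

    assert UserProvidedKeyBlockWritePathProvider is ProviderFromNewModule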
Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@ (new file, all lines added)

"""Ray PandasFileBasedDatasource Module."""
import logging
from typing import Any, Callable, Dict, List, Optional

import pandas as pd
import pyarrow
from ray.data._internal.pandas_block import PandasBlockAccessor
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.datasource.datasource import WriteResult
from ray.data.datasource.file_based_datasource import (
    BlockWritePathProvider,
    DefaultBlockWritePathProvider,
    FileBasedDatasource,
)
from ray.types import ObjectRef

from awswrangler.s3._fs import open_s3_object
from awswrangler.s3._write import _COMPRESSION_2_EXT

_logger: logging.Logger = logging.getLogger(__name__)


class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider):
    """Block write path provider.

    Used when writing single-block datasets into a user-provided S3 key.
    """

    def _get_write_path_for_block(
        self,
        base_path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        dataset_uuid: Optional[str] = None,
        block: Optional[ObjectRef[Block[Any]]] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        return base_path


class PandasFileBasedDatasource(FileBasedDatasource):  # pylint: disable=abstract-method
    """Pandas file based datasource, for reading and writing Pandas blocks."""

    _FILE_EXTENSION: Optional[str] = None

    def __init__(self) -> None:
        super().__init__()

        self._write_paths: List[str] = []

    def _read_file(self, f: pyarrow.NativeFile, path: str, **reader_args: Any) -> pd.DataFrame:
        raise NotImplementedError()

    def do_write(  # type: ignore  # pylint: disable=arguments-differ
        self,
        blocks: List[ObjectRef[pd.DataFrame]],
        metadata: List[BlockMetadata],
        path: str,
        dataset_uuid: str,
        filesystem: Optional[pyarrow.fs.FileSystem] = None,
        try_create_dir: bool = True,
        open_stream_args: Optional[Dict[str, Any]] = None,
        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        _block_udf: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
        ray_remote_args: Optional[Dict[str, Any]] = None,
        s3_additional_kwargs: Optional[Dict[str, str]] = None,
        pandas_kwargs: Optional[Dict[str, Any]] = None,
        compression: Optional[str] = None,
        mode: str = "wb",
        **write_args: Any,
    ) -> List[ObjectRef[WriteResult]]:
        """Create and return write tasks for a file-based datasource."""
        _write_block_to_file = self._write_block

        if ray_remote_args is None:
            ray_remote_args = {}

        if pandas_kwargs is None:
            pandas_kwargs = {}

        if not compression:
            compression = pandas_kwargs.get("compression")

        def write_block(write_path: str, block: pd.DataFrame) -> str:
            _logger.debug("Writing %s file.", write_path)

            if _block_udf is not None:
                block = _block_udf(block)

            with open_s3_object(
                path=write_path,
                mode=mode,
                use_threads=False,
                s3_additional_kwargs=s3_additional_kwargs,
                encoding=write_args.get("encoding"),
                newline=write_args.get("newline"),
            ) as f:
                _write_block_to_file(
                    f,
                    PandasBlockAccessor(block),
                    pandas_kwargs=pandas_kwargs,
                    compression=compression,
                    **write_args,
                )
            return write_path

        write_block_fn = cached_remote_fn(write_block).options(**ray_remote_args)

        file_format = self._FILE_EXTENSION
        write_tasks = []

        for block_idx, block in enumerate(blocks):
            write_path = block_path_provider(
                path,
                filesystem=filesystem,
                dataset_uuid=dataset_uuid,
                block=block,
                block_index=block_idx,
                file_format=f"{file_format}{_COMPRESSION_2_EXT.get(compression)}",
            )
            write_task = write_block_fn.remote(write_path, block)
            write_tasks.append(write_task)

        return write_tasks

    def _write_block(
        self,
        f: "pyarrow.NativeFile",
        block: BlockAccessor[Any],
        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        **writer_args: Any,
    ) -> None:
        raise NotImplementedError("Subclasses of PandasFileBasedDatasource must implement _write_block().")

    def on_write_complete(self, write_results: List[Any], **_: Any) -> None:
        """Execute callback on write complete."""
        _logger.debug("Write complete %s.", write_results)

        # Collect and return all write task paths
        self._write_paths.extend(write_results)

    def on_write_failed(self, write_results: List[ObjectRef[Any]], error: Exception, **_: Any) -> None:
        """Execute callback on write failed."""
        _logger.debug("Write failed %s.", write_results)
        raise error

    def get_write_paths(self) -> List[str]:
        """Return S3 paths of where the results have been written."""
        return self._write_paths
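
To make the base-class contract concrete, here is a minimal, hypothetical subclass sketch (not part of this commit): do_write() opens the S3 object and schedules the Ray write tasks, so a concrete datasource only implements _write_block() to serialize one Pandas block. The class name PandasEchoCSVDatasource and its use of DataFrame.to_csv are illustrative assumptions, not code from this PR.

    from typing import Any, Callable, Dict, Optional

    import pandas as pd
    import pyarrow
    from ray.data.block import BlockAccessor

    from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource


    class PandasEchoCSVDatasource(PandasFileBasedDatasource):  # hypothetical example
        """Illustrative CSV datasource showing the _write_block() contract."""

        # do_write() appends this extension (plus any compression suffix) to output paths.
        _FILE_EXTENSION = "csv"

        def _write_block(
            self,
            f: "pyarrow.NativeFile",
            block: BlockAccessor[Any],
            writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
            pandas_kwargs: Optional[Dict[str, Any]] = None,
            **writer_args: Any,
        ) -> None:
            # do_write() has already opened the target S3 object and passes the
            # file handle in; the subclass only decides how a block is serialized.
            frame: pd.DataFrame = block.to_pandas()
            frame.to_csv(f, **(pandas_kwargs or {}))

Hoisting this shared do_write() path out of the individual datasources is the stated point of the refactor: S3 I/O, remote-task scheduling, and output-path construction are handled once in the base class, and each format-specific datasource shrinks to its serialization logic.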
