Commit 9d5c1a6

feat: Upgrade to Ray 2.9.0+ and refactor Ray datasources to the new API (#2570)
* chore: Bump Ray to 2.8.1
* [skip ci] upgrade modin/pandas
* [skip ci] fix imports
* [skip ci] checkpoint
* [skip ci] Write API changes: add ParquetDatasink
* [skip ci] Bump to 2.9.0
* [skip ci] Minor refactoring of ParquetDatasink
* [skip ci] Checkpoint - add CSV and JSON datasink
* [skip ci] Extend from _BlockFileDatasink
* [skip ci] Fix parquet params
* [skip ci] Fix text datasinks & add ArrowCSVDatasink
* [skip ci] Refactor text datasources to the new API
* [skip ci] Text datasource fixes
* [skip ci] Add ORC datasink; update ORC datasource to the new API; refactoring
* [skip ci] Minor fixes
* [skip ci] Checkpoint - adapt Parquet datasources
* [skip ci] Fix output metadata handling & minor fixes
* [skip ci] Mypy
* Tests - JSON 'index=True' is only valid when 'orient' is 'split', 'table', 'index' or 'columns'.
* [skip ci] Fix text datasink compression file extension
* [skip ci] Fix ORC test named index serialization
* [skip ci] Fix pyproject.toml & poetry.lock marking pandas as optional
* Inconsistent schema resolution on Modin dfs
* [skip ci] Miror refactoring - arrow parquet props
* [skip ci] Open s3 object props
* [skip ci] Move CSV write options
* Minor - index & dtype handling
* Fix poetry.lock

Signed-off-by: Anton Kukushkin <[email protected]>
1 parent df8b76a commit 9d5c1a6

28 files changed (+1218, -853 lines)
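The refactor follows the split introduced by the Ray 2.9 Data API: read paths subclass ray.data's FileBasedDatasource and bind their reader options at construction time, while write paths move out of the datasource into dedicated datasink classes with a per-block write hook. Below is a minimal, illustrative sketch of the read-side shape; the class name is hypothetical and not part of this commit, it only mirrors the pattern the diffs below follow.

from typing import Any, Iterator, List, Union

import pyarrow as pa
from pyarrow import csv
from ray.data.datasource.file_based_datasource import FileBasedDatasource


class ExampleCSVDatasource(FileBasedDatasource):  # hypothetical name, illustration only
    """Sketch of the Ray 2.9+ read pattern used by the datasources in this commit."""

    _FILE_EXTENSIONS = ["csv"]

    def __init__(self, paths: Union[str, List[str]], **file_based_datasource_kwargs: Any):
        # Reader options are bound once at construction time instead of being
        # threaded through **reader_args on every read call (the pre-2.9 pattern).
        super().__init__(paths, **file_based_datasource_kwargs)
        self.read_options = csv.ReadOptions(use_threads=False)

    def _read_stream(self, f: pa.NativeFile, path: str) -> Iterator[pa.Table]:
        # Stream record batches out of the open file handle as Arrow tables.
        reader = csv.open_csv(f, read_options=self.read_options)
        while True:
            try:
                yield pa.Table.from_batches([reader.read_next_batch()])
            except StopIteration:
                return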

awswrangler/distributed/ray/datasources/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -1,11 +1,16 @@
 """Ray Datasources Module."""

+from awswrangler.distributed.ray.datasources.arrow_csv_datasink import ArrowCSVDatasink
 from awswrangler.distributed.ray.datasources.arrow_csv_datasource import ArrowCSVDatasource
 from awswrangler.distributed.ray.datasources.arrow_json_datasource import ArrowJSONDatasource
+from awswrangler.distributed.ray.datasources.arrow_orc_datasink import ArrowORCDatasink
 from awswrangler.distributed.ray.datasources.arrow_orc_datasource import ArrowORCDatasource
 from awswrangler.distributed.ray.datasources.arrow_parquet_base_datasource import ArrowParquetBaseDatasource
+from awswrangler.distributed.ray.datasources.arrow_parquet_datasink import ArrowParquetDatasink
 from awswrangler.distributed.ray.datasources.arrow_parquet_datasource import ArrowParquetDatasource
-from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import UserProvidedKeyBlockWritePathProvider
+from awswrangler.distributed.ray.datasources.block_path_provider import UserProvidedKeyBlockWritePathProvider
+from awswrangler.distributed.ray.datasources.file_datasink import _BlockFileDatasink
+from awswrangler.distributed.ray.datasources.pandas_text_datasink import PandasCSVDatasink, PandasJSONDatasink
 from awswrangler.distributed.ray.datasources.pandas_text_datasource import (
     PandasCSVDataSource,
     PandasFWFDataSource,
@@ -14,6 +19,9 @@
 )

 __all__ = [
+    "ArrowCSVDatasink",
+    "ArrowORCDatasink",
+    "ArrowParquetDatasink",
     "ArrowCSVDatasource",
     "ArrowJSONDatasource",
     "ArrowORCDatasource",
@@ -24,4 +32,7 @@
     "PandasJSONDatasource",
     "PandasTextDatasource",
     "UserProvidedKeyBlockWritePathProvider",
+    "PandasCSVDatasink",
+    "PandasJSONDatasink",
+    "_BlockFileDatasink",
 ]
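With the exports above, the new sinks are importable from the same package as the existing datasources, for example (illustrative only; all names appear in __all__ above):

from awswrangler.distributed.ray.datasources import (
    ArrowCSVDatasource,
    ArrowParquetDatasink,
    PandasJSONDatasink,
)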

awswrangler/distributed/ray/datasources/arrow_csv_datasink.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+"""Ray ArrowCSVDatasink Module."""
+
+import io
+import logging
+from typing import Any, Dict, Optional
+
+from pyarrow import csv
+from ray.data.block import BlockAccessor
+from ray.data.datasource.block_path_provider import BlockWritePathProvider
+
+from awswrangler.distributed.ray.datasources.file_datasink import _BlockFileDatasink
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+class ArrowCSVDatasink(_BlockFileDatasink):
+    """A datasink that writes CSV files using Arrow."""
+
+    def __init__(
+        self,
+        path: str,
+        *,
+        block_path_provider: Optional[BlockWritePathProvider] = None,
+        dataset_uuid: Optional[str] = None,
+        open_s3_object_args: Optional[Dict[str, Any]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        write_options: Optional[Dict[str, Any]] = None,
+        **write_args: Any,
+    ):
+        super().__init__(
+            path,
+            file_format="csv",
+            block_path_provider=block_path_provider,
+            dataset_uuid=dataset_uuid,
+            open_s3_object_args=open_s3_object_args,
+            pandas_kwargs=pandas_kwargs,
+            **write_args,
+        )
+
+        self.write_options = write_options or {}
+
+    def write_block(self, file: io.TextIOWrapper, block: BlockAccessor) -> None:
+        """
+        Write a block of data to a file.
+
+        Parameters
+        ----------
+        block : BlockAccessor
+        file : io.TextIOWrapper
+        """
+        csv.write_csv(block.to_arrow(), file, csv.WriteOptions(**self.write_options))
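A sink like this is meant to be handed to Ray's write path (awswrangler's own write functions construct it internally). A hedged usage sketch, assuming Ray 2.9's Dataset.write_datasink() API; the bucket path and the sample rows below are placeholders:

import ray

from awswrangler.distributed.ray.datasources import ArrowCSVDatasink

ds = ray.data.from_items([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}])
# write_options is forwarded to pyarrow.csv.WriteOptions inside write_block().
ds.write_datasink(
    ArrowCSVDatasink("s3://example-bucket/csv/", write_options={"include_header": True})
)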

awswrangler/distributed/ray/datasources/arrow_csv_datasource.py

Lines changed: 34 additions & 34 deletions
@@ -1,39 +1,50 @@
 """Ray ArrowCSVDatasource Module."""
-from typing import Any, Iterator
+from typing import Any, Dict, Iterator, List, Optional, Union

 import pyarrow as pa
 from pyarrow import csv
-from ray.data.block import BlockAccessor
+from ray.data.datasource.file_based_datasource import FileBasedDatasource

 from awswrangler._arrow import _add_table_partitions
-from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource


-class ArrowCSVDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
-    """CSV datasource, for reading and writing CSV files using PyArrow."""
+class ArrowCSVDatasource(FileBasedDatasource):
+    """CSV datasource, for reading CSV files using PyArrow."""

-    _FILE_EXTENSION = "csv"
+    _FILE_EXTENSIONS = ["csv"]

-    def _read_stream(  # type: ignore[override]  # pylint: disable=arguments-differ
+    def __init__(
         self,
-        f: pa.NativeFile,
-        path: str,
-        path_root: str,
+        paths: Union[str, List[str]],
         dataset: bool,
-        **reader_args: Any,
-    ) -> Iterator[pa.Table]:
-        read_options = reader_args.get("read_options", csv.ReadOptions(use_threads=False))
-        parse_options = reader_args.get(
-            "parse_options",
-            csv.ParseOptions(),
-        )
-        convert_options = reader_args.get("convert_options", csv.ConvertOptions())
+        path_root: str,
+        version_ids: Optional[Dict[str, str]] = None,
+        s3_additional_kwargs: Optional[Dict[str, str]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        arrow_csv_args: Optional[Dict[str, Any]] = None,
+        **file_based_datasource_kwargs: Any,
+    ):
+        from pyarrow import csv
+
+        super().__init__(paths, **file_based_datasource_kwargs)
+
+        self.dataset = dataset
+        self.path_root = path_root

+        if arrow_csv_args is None:
+            arrow_csv_args = {}
+
+        self.read_options = arrow_csv_args.pop("read_options", csv.ReadOptions(use_threads=False))
+        self.parse_options = arrow_csv_args.pop("parse_options", csv.ParseOptions())
+        self.convert_options = arrow_csv_args.get("convert_options", csv.ConvertOptions())
+        self.arrow_csv_args = arrow_csv_args
+
+    def _read_stream(self, f: pa.NativeFile, path: str) -> Iterator[pa.Table]:
         reader = csv.open_csv(
             f,
-            read_options=read_options,
-            parse_options=parse_options,
-            convert_options=convert_options,
+            read_options=self.read_options,
+            parse_options=self.parse_options,
+            convert_options=self.convert_options,
         )

         schema = None
@@ -44,25 +55,14 @@ def _read_stream(  # type: ignore[override]  # pylint: disable=arguments-differ
                 if schema is None:
                     schema = table.schema

-                if dataset:
+                if self.dataset:
                     table = _add_table_partitions(
                         table=table,
                         path=f"s3://{path}",
-                        path_root=path_root,
+                        path_root=self.path_root,
                     )

                 yield table

             except StopIteration:
                 return
-
-    def _write_block(  # type: ignore[override]  # pylint: disable=arguments-differ
-        self,
-        f: pa.NativeFile,
-        block: BlockAccessor,
-        **writer_args: Any,
-    ) -> None:
-        write_options_dict = writer_args.get("write_options", {})
-        write_options = csv.WriteOptions(**write_options_dict)
-
-        csv.write_csv(block.to_arrow(), f, write_options)
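On the read side, the datasource is now fully configured up front and handed to ray.data.read_datasource(); in practice awswrangler's distributed read path constructs this object itself. A hedged, standalone sketch with placeholder paths:

import ray
from pyarrow import csv

from awswrangler.distributed.ray.datasources import ArrowCSVDatasource

datasource = ArrowCSVDatasource(
    paths=["example-bucket/data/part-0.csv"],  # placeholder key; _read_stream prefixes it with s3://
    dataset=True,
    path_root="example-bucket/data/",
    arrow_csv_args={"read_options": csv.ReadOptions(use_threads=False)},
)
ds = ray.data.read_datasource(datasource)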

awswrangler/distributed/ray/datasources/arrow_json_datasource.py

Lines changed: 29 additions & 19 deletions
@@ -1,39 +1,49 @@
 """Ray ArrowCSVDatasource Module."""
-from typing import Any
+from typing import Any, Dict, Iterator, List, Optional, Union

 import pyarrow as pa
 from pyarrow import json
+from ray.data.datasource.file_based_datasource import FileBasedDatasource

 from awswrangler._arrow import _add_table_partitions
-from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource


-class ArrowJSONDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
-    """JSON datasource, for reading and writing JSON files using PyArrow."""
+class ArrowJSONDatasource(FileBasedDatasource):  # pylint: disable=abstract-method
+    """JSON datasource, for reading JSON files using PyArrow."""

-    _FILE_EXTENSION = "json"
+    _FILE_EXTENSIONS = ["json"]

-    def _read_file(  # type: ignore[override]  # pylint: disable=arguments-differ
+    def __init__(
         self,
-        f: pa.NativeFile,
-        path: str,
-        path_root: str,
+        paths: Union[str, List[str]],
         dataset: bool,
-        **reader_args: Any,
-    ) -> pa.Table:
-        read_options_dict = reader_args.get("read_options", dict(use_threads=False))
-        parse_options_dict = reader_args.get("parse_options", {})
+        path_root: str,
+        version_ids: Optional[Dict[str, str]] = None,
+        s3_additional_kwargs: Optional[Dict[str, str]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        arrow_json_args: Optional[Dict[str, Any]] = None,
+        **file_based_datasource_kwargs: Any,
+    ):
+        super().__init__(paths, **file_based_datasource_kwargs)
+
+        self.dataset = dataset
+        self.path_root = path_root
+
+        if arrow_json_args is None:
+            arrow_json_args = {}

-        read_options = json.ReadOptions(**read_options_dict)
-        parse_options = json.ParseOptions(**parse_options_dict)
+        self.read_options = json.ReadOptions(arrow_json_args.pop("read_options", dict(use_threads=False)))
+        self.parse_options = json.ParseOptions(arrow_json_args.pop("parse_options", {}))
+        self.arrow_json_args = arrow_json_args

-        table = json.read_json(f, read_options=read_options, parse_options=parse_options)
+    def _read_stream(self, f: pa.NativeFile, path: str) -> Iterator[pa.Table]:
+        table = json.read_json(f, read_options=self.read_options, parse_options=self.parse_options)

-        if dataset:
+        if self.dataset:
             table = _add_table_partitions(
                 table=table,
                 path=f"s3://{path}",
-                path_root=path_root,
+                path_root=self.path_root,
             )

-        return table
+        return [table]  # type: ignore[return-value]

awswrangler/distributed/ray/datasources/arrow_orc_datasink.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+"""Ray ArrowORCDatasink Module."""
+
+import io
+import logging
+from typing import Any, Dict, Optional
+
+import pyarrow as pa
+from ray.data.block import BlockAccessor
+from ray.data.datasource.block_path_provider import BlockWritePathProvider
+
+from awswrangler._arrow import _df_to_table
+from awswrangler.distributed.ray.datasources.file_datasink import _BlockFileDatasink
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+class ArrowORCDatasink(_BlockFileDatasink):
+    """A datasink that writes ORC files using Arrow."""
+
+    def __init__(
+        self,
+        path: str,
+        *,
+        block_path_provider: Optional[BlockWritePathProvider] = None,
+        dataset_uuid: Optional[str] = None,
+        open_s3_object_args: Optional[Dict[str, Any]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        schema: Optional[pa.Schema] = None,
+        index: bool = False,
+        dtype: Optional[Dict[str, str]] = None,
+        pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+        **write_args: Any,
+    ):
+        super().__init__(
+            path,
+            file_format="orc",
+            block_path_provider=block_path_provider,
+            dataset_uuid=dataset_uuid,
+            open_s3_object_args=open_s3_object_args,
+            pandas_kwargs=pandas_kwargs,
+            **write_args,
+        )
+
+        self.pyarrow_additional_kwargs = pyarrow_additional_kwargs or {}
+        self.schema = schema
+        self.index = index
+        self.dtype = dtype
+
+    def write_block(self, file: io.TextIOWrapper, block: BlockAccessor) -> None:
+        """
+        Write a block of data to a file.
+
+        Parameters
+        ----------
+        file : io.TextIOWrapper
+        block : BlockAccessor
+        """
+        from pyarrow import orc
+
+        compression: str = self.write_args.get("compression", None) or "UNCOMPRESSED"
+
+        orc.write_table(
+            _df_to_table(block.to_pandas(), schema=self.schema, index=self.index, dtype=self.dtype),
+            file,
+            compression=compression,
+            **self.pyarrow_additional_kwargs,
+        )
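As with the CSV sink, a hedged usage sketch with a placeholder destination and made-up rows; schema, index, and dtype mirror the constructor arguments above, and compression travels through **write_args to the orc.write_table() call in write_block(). It assumes Ray 2.9's Dataset.write_datasink() API:

import pyarrow as pa
import ray

from awswrangler.distributed.ray.datasources import ArrowORCDatasink

sink = ArrowORCDatasink(
    "s3://example-bucket/orc/",  # placeholder destination
    schema=pa.schema([("a", pa.int64())]),
    index=False,
    compression="snappy",  # picked up from write_args in write_block()
)
ray.data.from_items([{"a": 1}, {"a": 2}]).write_datasink(sink)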
