
Commit 665141e

Optimize distributed CSV I/O by adding PyArrow-based datasource (#1699)
1 parent 86e7b26 commit 665141e

6 files changed: +247 -21 lines changed


awswrangler/distributed/ray/datasources/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Ray Datasources Module."""

+from awswrangler.distributed.ray.datasources.arrow_csv_datasource import ArrowCSVDatasource
 from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import UserProvidedKeyBlockWritePathProvider
 from awswrangler.distributed.ray.datasources.pandas_text_datasource import (
     PandasCSVDataSource,
@@ -10,6 +11,7 @@
 from awswrangler.distributed.ray.datasources.parquet_datasource import ParquetDatasource

 __all__ = [
+    "ArrowCSVDatasource",
     "PandasCSVDataSource",
     "PandasFWFDataSource",
     "PandasJSONDatasource",
awswrangler/distributed/ray/datasources/arrow_csv_datasource.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+"""Ray ArrowCSVDatasource Module."""
+from typing import Any, Iterator
+
+import pyarrow as pa
+from pyarrow import csv
+from ray.data.block import BlockAccessor
+
+from awswrangler._arrow import _add_table_partitions
+from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource
+
+
+class ArrowCSVDatasource(PandasFileBasedDatasource):  # pylint: disable=abstract-method
+    """CSV datasource, for reading and writing CSV files using PyArrow."""
+
+    _FILE_EXTENSION = "csv"
+
+    def _read_stream(  # type: ignore  # pylint: disable=arguments-differ
+        self,
+        f: pa.NativeFile,
+        path: str,
+        path_root: str,
+        dataset: bool,
+        **reader_args: Any,
+    ) -> Iterator[pa.Table]:
+        read_options = reader_args.get("read_options", csv.ReadOptions(use_threads=False))
+        parse_options = reader_args.get(
+            "parse_options",
+            csv.ParseOptions(),
+        )
+        convert_options = reader_args.get("convert_options", csv.ConvertOptions())
+
+        reader = csv.open_csv(
+            f,
+            read_options=read_options,
+            parse_options=parse_options,
+            convert_options=convert_options,
+        )
+
+        schema = None
+        while True:
+            try:
+                batch = reader.read_next_batch()
+                table = pa.Table.from_batches([batch], schema=schema)
+                if schema is None:
+                    schema = table.schema
+
+                if dataset:
+                    table = _add_table_partitions(
+                        table=table,
+                        path=f"s3://{path}",
+                        path_root=path_root,
+                    )
+
+                yield table
+
+            except StopIteration:
+                return
+
+    def _write_block(  # type: ignore  # pylint: disable=arguments-differ
+        self,
+        f: pa.NativeFile,
+        block: BlockAccessor[Any],
+        **writer_args: Any,
+    ) -> None:
+        write_options_dict = writer_args.get("write_options", {})
+        write_options = csv.WriteOptions(**write_options_dict)
+
+        csv.write_csv(block.to_arrow(), f, write_options)

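The read path of this datasource is PyArrow's incremental CSV reader: open the stream once, pull record batches one at a time, and pin the schema from the first batch so later batches cannot drift. A minimal standalone sketch of the same pattern, outside Ray and S3 (the local file name is hypothetical):

import pyarrow as pa
from pyarrow import csv

def stream_csv_tables(path: str):
    """Yield one pa.Table per record batch, pinning the schema on the first batch."""
    reader = csv.open_csv(path, read_options=csv.ReadOptions(use_threads=False))
    schema = None
    while True:
        try:
            batch = reader.read_next_batch()
        except StopIteration:
            return
        table = pa.Table.from_batches([batch], schema=schema)
        if schema is None:
            schema = table.schema
        yield table

for chunk in stream_csv_tables("example.csv"):  # hypothetical local file
    print(chunk.num_rows)
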
awswrangler/distributed/ray/modin/_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Modin on Ray utilities (PRIVATE)."""
2-
from typing import Any, Callable, Dict, List, Optional, Union
2+
from dataclasses import dataclass
3+
from typing import Any, Callable, Dict, List, Optional, Set, Union
34

45
import modin.pandas as modin_pd
56
import pandas as pd
@@ -8,6 +9,7 @@
89
from ray.data._internal.arrow_block import ArrowBlockAccessor, ArrowRow
910
from ray.data._internal.remote_fn import cached_remote_fn
1011

12+
from awswrangler import exceptions
1113
from awswrangler._arrow import _table_to_df
1214

1315

@@ -43,3 +45,30 @@ def _to_modin(
4345

4446
def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Optional[Dict[str, Any]]) -> modin_pd.DataFrame:
4547
return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), to_pandas_kwargs=kwargs)
48+
49+
50+
@dataclass
51+
class ParamConfig:
52+
"""
53+
Configuration for a Pandas argument that is supported in PyArrow.
54+
55+
Contains a default value and, optionally, a list of supports values.
56+
"""
57+
58+
default: Any
59+
supported_values: Optional[Set[Any]] = None
60+
61+
62+
def _check_parameters(pandas_kwargs: Dict[str, Any], supported_params: Dict[str, ParamConfig]) -> None:
63+
for pandas_arg_key, pandas_args_value in pandas_kwargs.items():
64+
if pandas_arg_key not in supported_params:
65+
raise exceptions.InvalidArgument(f"Unsupported Pandas parameter for PyArrow loader: {pandas_arg_key}")
66+
67+
param_config = supported_params[pandas_arg_key]
68+
if param_config.supported_values is None:
69+
continue
70+
71+
if pandas_args_value not in param_config.supported_values:
72+
raise exceptions.InvalidArgument(
73+
f"Unsupported Pandas parameter value for PyArrow loader: {pandas_arg_key}={pandas_args_value}",
74+
)

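To see what _check_parameters accepts and rejects, here is a self-contained sketch of the same validation pattern; awswrangler's exceptions.InvalidArgument is swapped for a plain ValueError so the snippet runs without the library:

from dataclasses import dataclass
from typing import Any, Dict, Optional, Set

@dataclass
class ParamConfig:
    default: Any
    supported_values: Optional[Set[Any]] = None  # None means any value is accepted

SUPPORTED_PARAMS: Dict[str, ParamConfig] = {
    "sep": ParamConfig(default=","),                              # any separator allowed
    "index": ParamConfig(default=True, supported_values={True}),  # only True allowed
}

def check_parameters(kwargs: Dict[str, Any]) -> None:
    for key, value in kwargs.items():
        if key not in SUPPORTED_PARAMS:
            raise ValueError(f"Unsupported parameter: {key}")
        config = SUPPORTED_PARAMS[key]
        if config.supported_values is not None and value not in config.supported_values:
            raise ValueError(f"Unsupported parameter value: {key}={value}")

check_parameters({"sep": "|"})      # passes: "sep" has no value restriction
check_parameters({"index": True})   # passes: True is in the supported set
# check_parameters({"index": False})  would raise: only True is supported
# check_parameters({"skiprows": 2})   would raise: parameter not supported at all
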
awswrangler/distributed/ray/modin/s3/_read_text.py

Lines changed: 80 additions & 8 deletions
@@ -1,18 +1,70 @@
 """Modin on Ray S3 read text module (PRIVATE)."""
-from typing import Any, Dict, List, Optional, Union
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union

 import boto3
 import modin.pandas as pd
+from pyarrow import csv
 from ray.data import read_datasource

 from awswrangler import exceptions
-from awswrangler.distributed.ray.datasources import PandasCSVDataSource, PandasFWFDataSource, PandasJSONDatasource
-from awswrangler.distributed.ray.modin._utils import _to_modin
+from awswrangler.distributed.ray.datasources import (
+    ArrowCSVDatasource,
+    PandasCSVDataSource,
+    PandasFWFDataSource,
+    PandasJSONDatasource,
+)
+from awswrangler.distributed.ray.modin._utils import ParamConfig, _check_parameters, _to_modin

+_logger: logging.Logger = logging.getLogger(__name__)

-def _resolve_format(read_format: str) -> Any:
+_CSV_SUPPORTED_PARAMS = {
+    "sep": ParamConfig(default=","),
+    "delimiter": ParamConfig(default=","),
+    "quotechar": ParamConfig(default='"'),
+    "doublequote": ParamConfig(default=True),
+}
+
+
+def _parse_csv_configuration(
+    pandas_kwargs: Dict[str, Any],
+) -> Tuple[csv.ReadOptions, csv.ParseOptions, csv.ConvertOptions]:
+    _check_parameters(pandas_kwargs, _CSV_SUPPORTED_PARAMS)
+
+    read_options = csv.ReadOptions(
+        use_threads=False,
+    )
+    parse_options = csv.ParseOptions(
+        delimiter=pandas_kwargs.get("sep", _CSV_SUPPORTED_PARAMS["sep"].default),
+        quote_char=pandas_kwargs.get("quotechar", _CSV_SUPPORTED_PARAMS["quotechar"].default),
+        double_quote=pandas_kwargs.get("doublequote", _CSV_SUPPORTED_PARAMS["doublequote"].default),
+    )
+    convert_options = csv.ConvertOptions()
+
+    return read_options, parse_options, convert_options
+
+
+def _parse_configuration(
+    file_format: str,
+    version_ids: Dict[str, Optional[str]],
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    pandas_kwargs: Dict[str, Any],
+) -> Tuple[csv.ReadOptions, csv.ParseOptions, csv.ConvertOptions]:
+    if any(value is not None for value in version_ids.values()):
+        raise exceptions.InvalidArgument("Specific version ID found for object")
+
+    if s3_additional_kwargs:
+        raise exceptions.InvalidArgument(f"Additional S3 args specified: {s3_additional_kwargs}")
+
+    if file_format == "csv":
+        return _parse_csv_configuration(pandas_kwargs)
+
+    raise exceptions.InvalidArgument(f"Unsupported file format: {file_format}")
+
+
+def _resolve_format(read_format: str, can_use_arrow: bool) -> Any:
     if read_format == "csv":
-        return PandasCSVDataSource()
+        return ArrowCSVDatasource() if can_use_arrow else PandasCSVDataSource()
     if read_format == "fwf":
         return PandasFWFDataSource()
     if read_format == "json":
@@ -33,14 +85,34 @@ def _read_text_distributed(  # pylint: disable=unused-argument
     use_threads: Union[bool, int],
     boto3_session: Optional["boto3.Session"],
 ) -> pd.DataFrame:
-    ds = read_datasource(
-        datasource=_resolve_format(read_format),
+    try:
+        read_options, parse_options, convert_options = _parse_configuration(
+            read_format,
+            version_id_dict,
+            s3_additional_kwargs,
+            pandas_kwargs,
+        )
+        can_use_arrow = True
+    except exceptions.InvalidArgument as e:
+        _logger.warning(
+            "PyArrow method unavailable, defaulting to Pandas I/O functions: %s. "
+            "This will result in slower performance of the read operations.",
+            e,
+        )
+        read_options, parse_options, convert_options = None, None, None
+        can_use_arrow = False
+
+    ray_dataset = read_datasource(
+        datasource=_resolve_format(read_format, can_use_arrow),
         parallelism=parallelism,
         paths=paths,
         path_root=path_root,
         dataset=dataset,
         version_ids=version_id_dict,
         s3_additional_kwargs=s3_additional_kwargs,
         pandas_kwargs=pandas_kwargs,
+        read_options=read_options,
+        parse_options=parse_options,
+        convert_options=convert_options,
     )
-    return _to_modin(dataset=ds, ignore_index=ignore_index)
+    return _to_modin(dataset=ray_dataset, ignore_index=ignore_index)

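From the caller's perspective the dispatch above is transparent: when every Pandas kwarg maps onto a PyArrow option, the faster ArrowCSVDatasource is used; otherwise a warning is logged and the call falls back to PandasCSVDataSource. A sketch of both outcomes, assuming awswrangler is running in its Ray/Modin distributed mode (the bucket path is hypothetical):

import awswrangler as wr

# sep, delimiter, quotechar and doublequote are the only kwargs that
# _parse_csv_configuration can translate, so this call can take the
# PyArrow streaming path:
df_fast = wr.s3.read_csv("s3://my-bucket/data/*.csv", sep=",")

# skiprows has no entry in _CSV_SUPPORTED_PARAMS; _check_parameters raises
# InvalidArgument, the warning above is logged, and the slower Pandas
# datasource is used instead:
df_slow = wr.s3.read_csv("s3://my-bucket/data/*.csv", skiprows=2)
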
awswrangler/distributed/ray/modin/s3/_write_text.py

Lines changed: 66 additions & 11 deletions
@@ -10,17 +10,63 @@

 from awswrangler import exceptions
 from awswrangler.distributed.ray.datasources import (  # pylint: disable=ungrouped-imports
+    ArrowCSVDatasource,
     PandasCSVDataSource,
     PandasJSONDatasource,
-    PandasTextDatasource,
     UserProvidedKeyBlockWritePathProvider,
 )
+from awswrangler.distributed.ray.datasources.pandas_file_based_datasource import PandasFileBasedDatasource
+from awswrangler.distributed.ray.modin._utils import ParamConfig, _check_parameters
 from awswrangler.s3._write import _COMPRESSION_2_EXT
 from awswrangler.s3._write_text import _get_write_details

 _logger: logging.Logger = logging.getLogger(__name__)


+_CSV_SUPPORTED_PARAMS: Dict[str, ParamConfig] = {
+    "header": ParamConfig(default=True),
+    "sep": ParamConfig(default=",", supported_values={","}),
+    "index": ParamConfig(default=True, supported_values={True}),
+    "compression": ParamConfig(default=None, supported_values={None}),
+    "quoting": ParamConfig(default=None, supported_values={None}),
+    "escapechar": ParamConfig(default=None, supported_values={None}),
+    "date_format": ParamConfig(default=None, supported_values={None}),
+}
+
+
+def _parse_csv_configuration(
+    pandas_kwargs: Dict[str, Any],
+) -> Dict[str, Any]:
+    _check_parameters(pandas_kwargs, _CSV_SUPPORTED_PARAMS)
+
+    # csv.WriteOptions cannot be pickled, so we pass a plain Python dict instead
+    return {
+        "include_header": pandas_kwargs.get("header", _CSV_SUPPORTED_PARAMS["header"].default),
+    }
+
+
+def _parse_configuration(
+    file_format: str,
+    s3_additional_kwargs: Optional[Dict[str, str]],
+    pandas_kwargs: Dict[str, Any],
+) -> Dict[str, Any]:
+    if s3_additional_kwargs:
+        raise exceptions.InvalidArgument(f"Additional S3 args specified: {s3_additional_kwargs}")
+
+    if file_format == "csv":
+        return _parse_csv_configuration(pandas_kwargs)
+
+    raise exceptions.InvalidArgument(f"Unsupported file format: {file_format}")
+
+
+def _datasource_for_format(read_format: str, can_use_arrow: bool) -> PandasFileBasedDatasource:
+    if read_format == "csv":
+        return ArrowCSVDatasource() if can_use_arrow else PandasCSVDataSource()
+    if read_format == "json":
+        return PandasJSONDatasource()
+    raise exceptions.UnsupportedType("Unsupported read format")
+
+
 def _to_text_distributed(  # pylint: disable=unused-argument
     df: pd.DataFrame,
     file_format: str,
@@ -63,16 +109,24 @@ def _to_text_distributed(  # pylint: disable=unused-argument
         path,
     )

-    def _datasource_for_format(file_format: str) -> PandasTextDatasource:
-        if file_format == "csv":
-            return PandasCSVDataSource()
-
-        if file_format == "json":
-            return PandasJSONDatasource()
-
-        raise RuntimeError(f"Unknown file format: {file_format}")
+    # Figure out which datasource to use, and convert the Pandas parameters to PyArrow ones if possible
+    try:
+        write_options = _parse_configuration(
+            file_format,
+            s3_additional_kwargs,
+            pandas_kwargs,
+        )
+        can_use_arrow = True
+    except exceptions.InvalidArgument as e:
+        _logger.warning(
+            "PyArrow method unavailable, defaulting to Pandas I/O functions: %s. "
+            "This will result in slower performance of the write operations.",
+            e,
+        )
+        write_options = None
+        can_use_arrow = False

-    datasource = _datasource_for_format(file_format)
+    datasource = _datasource_for_format(file_format, can_use_arrow)

     mode, encoding, newline = _get_write_details(path=file_path, pandas_kwargs=pandas_kwargs)
     ds.write_datasource(
@@ -87,10 +141,11 @@ def _datasource_for_format(file_format: str) -> PandasTextDatasource:
         dataset_uuid=filename_prefix,
         boto3_session=None,
         s3_additional_kwargs=s3_additional_kwargs,
-        mode=mode,
+        mode="wb" if can_use_arrow else mode,
         encoding=encoding,
         newline=newline,
         pandas_kwargs=pandas_kwargs,
+        write_options=write_options,
     )

     return datasource.get_write_paths()

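The write side mirrors the read side but is stricter: most to_csv kwargs are pinned to a single supported value, so anything beyond the defaults falls back to the Pandas writer. A sketch under the same assumptions (Ray/Modin distributed mode, hypothetical bucket paths):

import awswrangler as wr
import modin.pandas as pd  # the distributed path operates on Modin DataFrames

df = pd.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# The defaults satisfy _CSV_SUPPORTED_PARAMS, so ArrowCSVDatasource writes
# the blocks in binary ("wb") mode:
wr.s3.to_csv(df, "s3://my-bucket/out/", dataset=True)

# compression is restricted to {None}: this logs the fallback warning and
# writes through PandasCSVDataSource instead:
wr.s3.to_csv(df, "s3://my-bucket/out-gz/", dataset=True, compression="gzip")
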
tests/load/test_s3.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def test_s3_delete_objects(path, path2, benchmark_time):
     assert len(wr.s3.list_objects(f"{path2}delete-test*")) == 0


-@pytest.mark.parametrize("benchmark_time", [240])
+@pytest.mark.parametrize("benchmark_time", [30])
 def test_s3_read_csv_simple(benchmark_time):
     path = "s3://nyc-tlc/csv_backup/yellow_tripdata_2021-0*.csv"
     with ExecutionTimer("elapsed time of wr.s3.read_csv() simple") as timer:
