
Commit 5a1f275

(feat): Refactor to distribute s3.read_parquet (#1513)
* (feat): Refactor and distribute s3.read_parquet

1. Refactor "wr.s3.read_parquet" and other methods in "_read_parquet" S3 module to reduce technical debt:
   - Leverage thread pool executor when possible
   - Simplify chunk generation logic
   - Reduce number of conditionals by generalising edge cases
   - Improve documentation

2. Distribute both "read_file_metadata" and "read_parquet" calls:
   - "read_file_metadata" is distributed as a "@ray_remote" method via the executor
   - "read_parquet" is distributed using a custom datasource and the "read_datasource" Ray public API
1 parent 8e3b4aa commit 5a1f275
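
For context, a minimal usage sketch of the public API this commit touches; the bucket path, the dataset=True/chunked=True options, and the forwarded to_pandas option are illustrative, not taken from the commit itself:

import awswrangler as wr

# Hypothetical dataset location; dataset=True asks the reader to add partition
# columns (e.g. year=2021/) parsed from the S3 key paths.
df = wr.s3.read_parquet(
    path="s3://my-bucket/my-dataset/",
    dataset=True,
    use_threads=True,
    pyarrow_additional_kwargs={"split_blocks": True},  # forwarded to pyarrow's to_pandas
)

# chunked=True returns an iterator of DataFrames instead of a single frame.
for chunk in wr.s3.read_parquet(path="s3://my-bucket/my-dataset/", dataset=True, chunked=True):
    print(chunk.shape)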

17 files changed (+676 / -705 lines)


awswrangler/_arrow.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
"""Arrow Utilities Module (PRIVATE)."""

import datetime
import json
import logging
from typing import Any, Dict, Optional, Tuple, cast

import pandas as pd
import pyarrow as pa

_logger: logging.Logger = logging.getLogger(__name__)


def _extract_partitions_from_path(path_root: str, path: str) -> Dict[str, str]:
    path_root = path_root if path_root.endswith("/") else f"{path_root}/"
    if path_root not in path:
        raise Exception(f"Object {path} is not under the root path ({path_root}).")
    path_wo_filename: str = path.rpartition("/")[0] + "/"
    path_wo_prefix: str = path_wo_filename.replace(f"{path_root}/", "")
    dirs: Tuple[str, ...] = tuple(x for x in path_wo_prefix.split("/") if (x != "") and (x.count("=") == 1))
    if not dirs:
        return {}
    values_tups = cast(Tuple[Tuple[str, str]], tuple(tuple(x.split("=")[:2]) for x in dirs))
    values_dics: Dict[str, str] = dict(values_tups)
    return values_dics


def _add_table_partitions(
    table: pa.Table,
    path: str,
    path_root: Optional[str],
) -> pa.Table:
    part = _extract_partitions_from_path(path_root, path) if path_root else None
    if part:
        for col, value in part.items():
            part_value = pa.array([value] * len(table)).dictionary_encode()
            if col not in table.schema.names:
                table = table.append_column(col, part_value)
            else:
                table = table.set_column(
                    table.schema.get_field_index(col),
                    col,
                    part_value,
                )
    return table


def _apply_timezone(df: pd.DataFrame, metadata: Dict[str, Any]) -> pd.DataFrame:
    for c in metadata["columns"]:
        if "field_name" in c and c["field_name"] is not None:
            col_name = str(c["field_name"])
        elif "name" in c and c["name"] is not None:
            col_name = str(c["name"])
        else:
            continue
        if col_name in df.columns and c["pandas_type"] == "datetimetz":
            timezone: datetime.tzinfo = pa.lib.string_to_tzinfo(c["metadata"]["timezone"])
            _logger.debug("applying timezone (%s) on column %s", timezone, col_name)
            if hasattr(df[col_name].dtype, "tz") is False:
                df[col_name] = df[col_name].dt.tz_localize(tz="UTC")
            df[col_name] = df[col_name].dt.tz_convert(tz=timezone)
    return df


def _table_to_df(
    table: pa.Table,
    kwargs: Dict[str, Any],
) -> pd.DataFrame:
    """Convert a PyArrow table to a Pandas DataFrame and apply metadata.

    This method should be used across the codebase to ensure this conversion is consistent.
    """
    metadata: Dict[str, Any] = {}
    if table.schema.metadata is not None and b"pandas" in table.schema.metadata:
        metadata = json.loads(table.schema.metadata[b"pandas"])

    df = table.to_pandas(**kwargs)

    if metadata:
        _logger.debug("metadata: %s", metadata)
        df = _apply_timezone(df=df, metadata=metadata)
    return df
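
A short sketch of how these new helpers compose; the table contents and the partitioned S3 key are illustrative only:

import pyarrow as pa

from awswrangler._arrow import _add_table_partitions, _table_to_df

# Illustrative table and a hypothetical partitioned S3 key under path_root.
table = pa.table({"value": [1.0, 2.0]})
table = _add_table_partitions(
    table=table,
    path="s3://bucket/dataset/year=2021/month=5/part-0.parquet",
    path_root="s3://bucket/dataset/",
)
# "year" and "month" are now dictionary-encoded columns of the table.
df = _table_to_df(table, kwargs={})  # to_pandas() plus timezone metadata handling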

awswrangler/_utils.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 from awswrangler import _config, exceptions
 from awswrangler.__metadata__ import __version__
+from awswrangler._arrow import _table_to_df
 from awswrangler._config import apply_configs, config
 
 if TYPE_CHECKING or config.distributed:
@@ -416,7 +417,7 @@ def table_refs_to_df(
 ) -> pd.DataFrame:
     """Build Pandas dataframe from list of PyArrow tables."""
     if isinstance(tables[0], pa.Table):
-        return ensure_df_is_mutable(pa.concat_tables(tables, promote=True).to_pandas(**kwargs))
+        return _table_to_df(pa.concat_tables(tables, promote=True), kwargs=kwargs)
     return _arrow_refs_to_df(arrow_refs=tables, kwargs=kwargs)  # type: ignore
 
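
The promote=True flag is what lets tables with slightly different schemas be concatenated before the conversion; a standalone PyArrow illustration (not awswrangler code):

import pyarrow as pa

t1 = pa.table({"a": [1, 2]})
t2 = pa.table({"a": [3], "b": ["x"]})

# promote=True unifies the schemas; columns missing from a table are null-filled.
merged = pa.concat_tables([t1, t2], promote=True)
print(merged.schema)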

awswrangler/athena/_read.py

Lines changed: 4 additions & 2 deletions
@@ -109,13 +109,15 @@ def _fetch_parquet_result(
         df = cast_pandas_with_athena_types(df=df, dtype=dtype_dict)
         df = _apply_query_metadata(df=df, query_metadata=query_metadata)
         return df
+    if not pyarrow_additional_kwargs:
+        pyarrow_additional_kwargs = {}
+    if categories:
+        pyarrow_additional_kwargs["categories"] = categories
     ret = s3.read_parquet(
         path=paths,
         use_threads=use_threads,
         boto3_session=boto3_session,
         chunked=chunked,
-        categories=categories,
-        ignore_index=True,
         pyarrow_additional_kwargs=pyarrow_additional_kwargs,
     )
     if chunked is False:
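
With this change, categories reach s3.read_parquet through pyarrow_additional_kwargs instead of a dedicated argument. A hedged sketch of the equivalent user-level call (query, database, and column names are hypothetical):

import awswrangler as wr

df = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table",           # hypothetical query
    database="my_db",                       # hypothetical Glue database
    ctas_approach=True,
    categories=["my_categorical_column"],   # internally forwarded via pyarrow_additional_kwargs
)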

awswrangler/distributed/_utils.py

Lines changed: 12 additions & 9 deletions
@@ -1,30 +1,33 @@
 """Utilities Module for Distributed methods."""
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List
 
 import modin.pandas as pd
 import pyarrow as pa
 import ray
 from modin.distributed.dataframe.pandas.partitions import from_partitions
-from ray.data.impl.arrow_block import ArrowBlockAccessor
+from ray.data.impl.arrow_block import ArrowBlockAccessor, ArrowRow
 from ray.data.impl.remote_fn import cached_remote_fn
 
+from awswrangler._arrow import _table_to_df
+
 
 def _block_to_df(
     block: Any,
     kwargs: Dict[str, Any],
-    dtype: Optional[Dict[str, str]] = None,
 ) -> pa.Table:
     block = ArrowBlockAccessor.for_block(block)
-    df = block._table.to_pandas(**kwargs)  # pylint: disable=protected-access
-    return df.astype(dtype=dtype) if dtype else df
+    return _table_to_df(table=block._table, kwargs=kwargs)  # pylint: disable=protected-access
 
 
-def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
-    ds = ray.data.from_arrow_refs(arrow_refs)
+def _to_modin(dataset: ray.data.Dataset[ArrowRow], kwargs: Dict[str, Any]) -> pd.DataFrame:
     block_to_df = cached_remote_fn(_block_to_df)
     return from_partitions(
-        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in ds.get_internal_block_refs()],
+        partitions=[block_to_df.remote(block=block, kwargs=kwargs) for block in dataset.get_internal_block_refs()],
         axis=0,
-        index=pd.RangeIndex(start=0, stop=ds.count()),
+        index=pd.RangeIndex(start=0, stop=dataset.count()),
     )
+
+
+def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Dict[str, Any]) -> pd.DataFrame:
+    return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), kwargs=kwargs)
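
A hedged sketch of the refactored path (Arrow object refs → Ray Dataset → Modin partitions); it assumes a running Ray cluster and illustrative table contents:

import pyarrow as pa
import ray

from awswrangler.distributed._utils import _arrow_refs_to_df

ray.init(ignore_reinit_error=True)

# Object refs pointing at Arrow tables, as distributed read tasks would produce them.
refs = [ray.put(pa.table({"a": [1, 2]})), ray.put(pa.table({"a": [3]}))]

modin_df = _arrow_refs_to_df(arrow_refs=refs, kwargs={})  # Modin DataFrame built from Ray blocks
print(modin_df)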
awswrangler/distributed/datasources/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
"""Distributed Datasources Module."""

from awswrangler.distributed.datasources.parquet_datasource import ParquetDatasource

__all__ = [
    "ParquetDatasource",
]
awswrangler/distributed/datasources/parquet_datasource.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
"""Distributed ParquetDatasource Module."""

import logging
from typing import Any, Callable, Iterator, List, Optional, Union

import numpy as np
import pyarrow as pa

# fs required to implicitly trigger S3 subsystem initialization
import pyarrow.fs  # noqa: F401 pylint: disable=unused-import
import pyarrow.parquet as pq
from ray import cloudpickle
from ray.data.context import DatasetContext
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_based_datasource import _resolve_paths_and_filesystem
from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider
from ray.data.datasource.parquet_datasource import (
    _deregister_parquet_file_fragment_serialization,
    _register_parquet_file_fragment_serialization,
)
from ray.data.impl.output_buffer import BlockOutputBuffer

from awswrangler._arrow import _add_table_partitions

_logger: logging.Logger = logging.getLogger(__name__)

# The number of rows to read per batch. This is sized to generate 10MiB batches
# for rows about 1KiB in size.
PARQUET_READER_ROW_BATCH_SIZE = 100000


class ParquetDatasource:
    """Parquet datasource, for reading and writing Parquet files."""

    # Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py
    def prepare_read(
        self,
        parallelism: int,
        use_threads: Union[bool, int],
        paths: Union[str, List[str]],
        schema: "pyarrow.lib.Schema",
        columns: Optional[List[str]] = None,
        coerce_int96_timestamp_unit: Optional[str] = None,
        path_root: Optional[str] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(),
        _block_udf: Optional[Callable[..., Any]] = None,
    ) -> List[ReadTask]:
        """Create and return read tasks for a Parquet file-based datasource."""
        paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem)

        parquet_dataset = pq.ParquetDataset(
            path_or_paths=paths,
            filesystem=filesystem,
            partitioning=None,
            coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
            use_legacy_dataset=False,
        )

        def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]:
            # Deserialize after loading the filesystem class.
            try:
                _register_parquet_file_fragment_serialization()  # type: ignore
                pieces = cloudpickle.loads(serialized_pieces)
            finally:
                _deregister_parquet_file_fragment_serialization()  # type: ignore

            # Ensure that we're reading at least one dataset fragment.
            assert len(pieces) > 0

            ctx = DatasetContext.get_current()
            output_buffer = BlockOutputBuffer(block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size)

            _logger.debug("Reading %s parquet pieces", len(pieces))
            for piece in pieces:
                batches = piece.to_batches(
                    use_threads=use_threads,
                    columns=columns,
                    schema=schema,
                    batch_size=PARQUET_READER_ROW_BATCH_SIZE,
                )
                for batch in batches:
                    # Table creation is wrapped inside _add_table_partitions
                    # to add columns with partition values when dataset=True
                    table = _add_table_partitions(
                        table=pa.Table.from_batches([batch], schema=schema),
                        path=f"s3://{piece.path}",
                        path_root=path_root,
                    )
                    # If the table is empty, drop it.
                    if table.num_rows > 0:
                        output_buffer.add_block(table)
                        if output_buffer.has_next():
                            yield output_buffer.next()

            output_buffer.finalize()
            if output_buffer.has_next():
                yield output_buffer.next()

        if _block_udf is not None:
            # Try to infer dataset schema by passing dummy table through UDF.
            dummy_table = schema.empty_table()
            try:
                inferred_schema = _block_udf(dummy_table).schema
                inferred_schema = inferred_schema.with_metadata(schema.metadata)
            except Exception:  # pylint: disable=broad-except
                _logger.debug(
                    "Failed to infer schema of dataset by passing dummy table "
                    "through UDF due to the following exception:",
                    exc_info=True,
                )
                inferred_schema = schema
        else:
            inferred_schema = schema
        read_tasks = []
        metadata = meta_provider.prefetch_file_metadata(parquet_dataset.pieces) or []
        try:
            _register_parquet_file_fragment_serialization()  # type: ignore
            for pieces, metadata in zip(  # type: ignore
                np.array_split(parquet_dataset.pieces, parallelism),
                np.array_split(metadata, parallelism),
            ):
                if len(pieces) <= 0:
                    continue
                serialized_pieces = cloudpickle.dumps(pieces)  # type: ignore
                input_files = [p.path for p in pieces]
                meta = meta_provider(
                    input_files,
                    inferred_schema,
                    pieces=pieces,
                    prefetched_metadata=metadata,
                )
                read_tasks.append(ReadTask(lambda p=serialized_pieces: read_pieces(p), meta))  # type: ignore
        finally:
            _deregister_parquet_file_fragment_serialization()  # type: ignore

        return read_tasks
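
Per the commit message, read_parquet is distributed through Ray's public read_datasource API. A hedged sketch of how this datasource might be driven; the keyword arguments mirror prepare_read above, while the S3 location and the exact call site inside awswrangler are assumptions:

import ray

from awswrangler.distributed.datasources import ParquetDatasource

dataset = ray.data.read_datasource(
    ParquetDatasource(),
    parallelism=200,
    use_threads=False,
    paths=["s3://bucket/dataset/"],    # hypothetical location
    schema=None,
    columns=None,
    path_root="s3://bucket/dataset/",  # enables partition-column injection via _add_table_partitions
)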

awswrangler/lakeformation/_read.py

Lines changed: 8 additions & 8 deletions
@@ -83,7 +83,7 @@ def read_sql_query(
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     params: Optional[Dict[str, Any]] = None,
-    arrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
 ) -> pd.DataFrame:
     """Execute PartiQL query on AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame.
 
@@ -126,10 +126,10 @@
         Dict of parameters used to format the partiQL query. Only named parameters are supported.
         The dict must contain the information in the form {"name": "value"} and the SQL query must contain
         `:name`.
-    arrow_additional_kwargs : Dict[str, Any], optional
+    pyarrow_additional_kwargs : Dict[str, Any], optional
         Forwarded to `to_pandas` method converting from PyArrow tables to Pandas dataframe.
         Valid values include "split_blocks", "self_destruct", "ignore_metadata".
-        e.g. arrow_additional_kwargs={'split_blocks': True}.
+        e.g. pyarrow_additional_kwargs={'split_blocks': True}.
 
     Returns
     -------
@@ -178,7 +178,7 @@
         **_transaction_id(transaction_id=transaction_id, query_as_of_time=query_as_of_time, DatabaseName=database),
     )
     query_id: str = client_lakeformation.start_query_planning(QueryString=sql, QueryPlanningContext=args)["QueryId"]
-    arrow_kwargs = _data_types.pyarrow2pandas_defaults(use_threads=use_threads, kwargs=arrow_additional_kwargs)
+    arrow_kwargs = _data_types.pyarrow2pandas_defaults(use_threads=use_threads, kwargs=pyarrow_additional_kwargs)
     df = _resolve_sql_query(
         query_id=query_id,
         use_threads=use_threads,
@@ -199,7 +199,7 @@ def read_sql_table(
     catalog_id: Optional[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
-    arrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
 ) -> pd.DataFrame:
     """Extract all rows from AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame.
 
@@ -232,10 +232,10 @@
         When enabled, os.cpu_count() is used as the max number of threads.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session is used if boto3_session receives None.
-    arrow_additional_kwargs : Dict[str, Any], optional
+    pyarrow_additional_kwargs : Dict[str, Any], optional
         Forwarded to `to_pandas` method converting from PyArrow tables to Pandas dataframe.
         Valid values include "split_blocks", "self_destruct", "ignore_metadata".
-        e.g. arrow_additional_kwargs={'split_blocks': True}.
+        e.g. pyarrow_additional_kwargs={'split_blocks': True}.
 
     Returns
     -------
@@ -276,5 +276,5 @@
         catalog_id=catalog_id,
         use_threads=use_threads,
         boto3_session=boto3_session,
-        arrow_additional_kwargs=arrow_additional_kwargs,
+        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
     )
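
The user-facing effect of the rename, sketched with hypothetical query and database names:

import awswrangler as wr

# The keyword is now pyarrow_additional_kwargs (formerly arrow_additional_kwargs).
df = wr.lakeformation.read_sql_query(
    sql="SELECT * FROM my_table",   # hypothetical PartiQL query
    database="my_db",               # hypothetical Glue database
    pyarrow_additional_kwargs={"split_blocks": True},
)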
