Skip to content

Commit 875e2ce

Browse files
authored
(perf): Distribute data types inference (#1692)
* (perf): Distribute data types inference * PR feedback: infer datatype based on first block
1 parent c0e75b9 commit 875e2ce

File tree

6 files changed

+35
-8
lines changed

6 files changed

+35
-8
lines changed

awswrangler/_data_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pyarrow.parquet
1414

1515
from awswrangler import _utils, exceptions
16+
from awswrangler._distributed import engine
1617

1718
_logger: logging.Logger = logging.getLogger(__name__)
1819

@@ -456,6 +457,7 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu
456457
return None
457458

458459

460+
@engine.dispatch_on_engine
459461
def pyarrow_types_from_pandas( # pylint: disable=too-many-branches,too-many-statements
460462
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
461463
) -> Dict[str, pa.DataType]:

awswrangler/distributed/ray/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
7474
return wrapper
7575

7676

77-
def ray_get(futures: List[Any]) -> List[Any]:
77+
def ray_get(futures: Union["ray.ObjectRef[Any]", List["ray.ObjectRef[Any]"]]) -> Any:
7878
"""
7979
Run ray.get on futures if distributed.
8080

awswrangler/distributed/ray/_core.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class RayLogger:
1414

1515
def ray_logger(function: Callable[..., Any]) -> Callable[..., Any]: ...
1616
def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]: ...
17-
def ray_get(futures: List[Any]) -> List[Any]: ...
17+
def ray_get(futures: List[Any]) -> Any: ...
1818
def initialize_ray(
1919
address: Optional[str] = None,
2020
redis_password: Optional[str] = None,

awswrangler/distributed/ray/_register.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Ray and Modin registered methods (PRIVATE)."""
22
# pylint: disable=import-outside-toplevel
3+
from awswrangler._data_types import pyarrow_types_from_pandas
34
from awswrangler._distributed import MemoryFormatEnum, engine, memory_format
45
from awswrangler._utils import table_refs_to_df
56
from awswrangler.distributed.ray._core import ray_remote
@@ -28,6 +29,7 @@ def register_ray() -> None:
2829

2930
if memory_format.get() == MemoryFormatEnum.MODIN:
3031
from awswrangler.distributed.ray.modin._core import modin_repartition
32+
from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed
3133
from awswrangler.distributed.ray.modin._utils import _arrow_refs_to_df
3234
from awswrangler.distributed.ray.modin.s3._read_parquet import _read_parquet_distributed
3335
from awswrangler.distributed.ray.modin.s3._read_text import _read_text_distributed
@@ -39,6 +41,7 @@ def register_ray() -> None:
3941
from awswrangler.distributed.ray.modin.s3._write_text import _to_text_distributed
4042

4143
for o_f, d_f in {
44+
pyarrow_types_from_pandas: pyarrow_types_from_pandas_distributed,
4245
_read_parquet: _read_parquet_distributed,
4346
_read_text: _read_text_distributed,
4447
_to_buckets: _to_buckets_distributed,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Internal (private) Data Types Module."""
2+
from typing import Dict, List, Optional
3+
4+
import modin.pandas as pd
5+
import pyarrow as pa
6+
import ray
7+
8+
from awswrangler._data_types import pyarrow_types_from_pandas
9+
from awswrangler.distributed.ray._core import ray_get, ray_remote
10+
11+
12+
def pyarrow_types_from_pandas_distributed(
13+
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
14+
) -> Dict[str, pa.DataType]:
15+
"""Extract the related Pyarrow data types from a pandas DataFrame."""
16+
func = ray_remote(pyarrow_types_from_pandas)
17+
first_block_object_ref = ray.data.from_modin(df).get_internal_block_refs()[0]
18+
return ray_get( # type: ignore
19+
func(
20+
df=first_block_object_ref,
21+
index=index,
22+
ignore_cols=ignore_cols,
23+
index_left=index_left,
24+
)
25+
)

awswrangler/s3/_select.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77
import pprint
8-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
8+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
99

1010
import boto3
1111
import pandas as pd
@@ -19,9 +19,6 @@
1919
from awswrangler.s3._list import _path2list
2020
from awswrangler.s3._read import _get_path_ignore_suffix
2121

22-
if TYPE_CHECKING:
23-
import ray # pylint: disable=unused-import
24-
2522
_logger: logging.Logger = logging.getLogger(__name__)
2623

2724
_RANGE_CHUNK_SIZE: int = int(1024 * 1024)
@@ -42,7 +39,7 @@ def _select_object_content(
4239
boto3_session: Optional[boto3.Session],
4340
args: Dict[str, Any],
4441
scan_range: Optional[Tuple[int, int]] = None,
45-
) -> Union[pa.Table, "ray.ObjectRef[pa.Table]"]:
42+
) -> pa.Table:
4643
client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
4744

4845
if scan_range:
@@ -83,7 +80,7 @@ def _select_query(
8380
scan_range_chunk_size: Optional[int] = None,
8481
boto3_session: Optional[boto3.Session] = None,
8582
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
86-
) -> List[Union[pa.Table, "ray.ObjectRef[pa.Table]"]]:
83+
) -> List[pa.Table]:
8784
bucket, key = _utils.parse_path(path)
8885

8986
args: Dict[str, Any] = {

0 commit comments

Comments
 (0)