Skip to content

Commit 7fd7302

Browse files
refactor: engine register_func to handle type checking (#2309)
1 parent e993c7d commit 7fd7302

File tree

8 files changed

+34
-35
lines changed

8 files changed

+34
-35
lines changed

awswrangler/_distributed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,10 @@ def dispatch_func(cls, source_func: FunctionType, value: Optional[EngineLiteral]
8989
return getattr(source_func, "_source_func", source_func)
9090

9191
@classmethod
92-
def register_func(cls, source_func: Callable[..., Any], destination_func: Callable[..., Any]) -> Callable[..., Any]:
92+
def register_func(cls, source_func: FunctionType, destination_func: FunctionType) -> None:
9393
"""Register a func based on the distribution engine and source function."""
9494
with cls._lock:
9595
cls._registry[cls.get().value][source_func.__name__] = destination_func
96-
return destination_func
9796

9897
@classmethod
9998
def dispatch_on_engine(cls, func: FunctionType) -> FunctionType:

awswrangler/distributed/ray/_core.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
from functools import wraps
5-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
5+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union
66

77
from awswrangler._config import apply_configs
88
from awswrangler._distributed import EngineEnum, engine
@@ -13,6 +13,9 @@
1313
_logger: logging.Logger = logging.getLogger(__name__)
1414

1515

16+
FunctionType = TypeVar("FunctionType", bound=Callable[..., Any])
17+
18+
1619
class RayLogger:
1720
"""Create discrete Logger instance for Ray Tasks."""
1821

@@ -31,10 +34,10 @@ def get_logger(self, name: Union[str, Any] = None) -> Optional[logging.Logger]:
3134

3235
@apply_configs
3336
def ray_logger(
34-
function: Callable[..., Any],
37+
function: FunctionType,
3538
configure_logging: bool = True,
3639
logging_level: int = logging.INFO,
37-
) -> Callable[..., Any]:
40+
) -> FunctionType:
3841
"""
3942
Decorate callable to add RayLogger.
4043
@@ -57,7 +60,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
5760
return wrapper
5861

5962

60-
def ray_remote(**options: Any) -> Callable[..., Any]:
63+
def ray_remote(**options: Any) -> Callable[[FunctionType], FunctionType]:
6164
"""
6265
Decorate with @ray.remote providing .options().
6366
@@ -71,7 +74,7 @@ def ray_remote(**options: Any) -> Callable[..., Any]:
7174
Callable[..., Any]
7275
"""
7376

74-
def remote_decorator(function: Callable[..., Any]) -> Callable[..., Any]:
77+
def remote_decorator(function: FunctionType) -> FunctionType:
7578
"""
7679
Decorate callable to wrap within ray.remote.
7780

awswrangler/distributed/ray/_register.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,11 @@ def register_ray() -> None:
5555
]:
5656
engine.register_func(func, ray_remote()(func))
5757

58-
for o_f, d_f in {
59-
_get_executor: _get_ray_executor,
60-
_list_objects_paginate: _list_objects_s3fs,
61-
_read_parquet_metadata_file: ray_remote()(_read_parquet_metadata_file_distributed),
62-
ensure_worker_or_thread_count: ensure_worker_count,
63-
}.items():
64-
engine.register_func(o_f, d_f) # type: ignore[arg-type]
58+
# Register dispatch methods for Ray
59+
engine.register_func(_get_executor, _get_ray_executor)
60+
engine.register_func(_list_objects_paginate, _list_objects_s3fs)
61+
engine.register_func(_read_parquet_metadata_file, ray_remote()(_read_parquet_metadata_file_distributed))
62+
engine.register_func(ensure_worker_or_thread_count, ensure_worker_count)
6563

6664
if memory_format.get() == MemoryFormatEnum.MODIN:
6765
from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed
@@ -80,17 +78,15 @@ def register_ray() -> None:
8078
from awswrangler.distributed.ray.modin.s3._write_parquet import _to_parquet_distributed
8179
from awswrangler.distributed.ray.modin.s3._write_text import _to_text_distributed
8280

83-
for o_f, d_f in {
84-
pyarrow_types_from_pandas: pyarrow_types_from_pandas_distributed,
85-
_read_parquet: _read_parquet_distributed,
86-
_read_text: _read_text_distributed,
87-
_to_buckets: _to_buckets_distributed,
88-
_to_parquet: _to_parquet_distributed,
89-
_to_partitions: _to_partitions_distributed,
90-
_to_text: _to_text_distributed,
91-
copy_df_shallow: _copy_modin_df_shallow,
92-
is_pandas_frame: _is_pandas_or_modin_frame,
93-
split_pandas_frame: _split_modin_frame,
94-
table_refs_to_df: _arrow_refs_to_df,
95-
}.items():
96-
engine.register_func(o_f, d_f) # type: ignore[arg-type]
81+
# Register dispatch methods for Modin
82+
engine.register_func(pyarrow_types_from_pandas, pyarrow_types_from_pandas_distributed)
83+
engine.register_func(_read_parquet, _read_parquet_distributed)
84+
engine.register_func(_read_text, _read_text_distributed)
85+
engine.register_func(_to_buckets, _to_buckets_distributed)
86+
engine.register_func(_to_parquet, _to_parquet_distributed)
87+
engine.register_func(_to_partitions, _to_partitions_distributed)
88+
engine.register_func(_to_text, _to_text_distributed)
89+
engine.register_func(copy_df_shallow, _copy_modin_df_shallow)
90+
engine.register_func(is_pandas_frame, _is_pandas_or_modin_frame)
91+
engine.register_func(split_pandas_frame, _split_modin_frame)
92+
engine.register_func(table_refs_to_df, _arrow_refs_to_df)

awswrangler/distributed/ray/modin/s3/_read_parquet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Modin on Ray S3 read parquet module (PRIVATE)."""
2-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
2+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
33

44
import modin.pandas as pd
55
import pyarrow as pa
@@ -37,7 +37,7 @@ def _read_parquet_distributed( # pylint: disable=unused-argument
3737
s3_additional_kwargs: Optional[Dict[str, Any]],
3838
arrow_kwargs: Dict[str, Any],
3939
bulk_read: bool,
40-
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
40+
) -> pd.DataFrame:
4141
dataset_kwargs = {}
4242
if coerce_int96_timestamp_unit:
4343
dataset_kwargs["coerce_int96_timestamp_unit"] = coerce_int96_timestamp_unit

awswrangler/distributed/ray/modin/s3/_read_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,14 @@ def _read_text_distributed( # pylint: disable=unused-argument
121121
read_format: str,
122122
paths: List[str],
123123
path_root: Optional[str],
124+
use_threads: Union[bool, int],
125+
s3_client: Optional["S3Client"],
124126
s3_additional_kwargs: Optional[Dict[str, str]],
125127
dataset: bool,
126128
ignore_index: bool,
127129
parallelism: int,
128130
version_ids: Optional[Dict[str, str]],
129131
pandas_kwargs: Dict[str, Any],
130-
use_threads: Union[bool, int],
131-
s3_client: Optional["S3Client"],
132132
) -> pd.DataFrame:
133133
try:
134134
configuration: Dict[str, Any] = _parse_configuration( # type: ignore[assignment]

awswrangler/s3/_read_parquet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _read_parquet( # pylint: disable=W0613
331331
s3_additional_kwargs: Optional[Dict[str, Any]],
332332
arrow_kwargs: Dict[str, Any],
333333
bulk_read: bool,
334-
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
334+
) -> pd.DataFrame:
335335
executor: _BaseExecutor = _get_executor(use_threads=use_threads)
336336
tables = executor.map(
337337
_read_parquet_file,

awswrangler/s3/_read_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _read_text( # pylint: disable=W0613
5252
parallelism: int,
5353
version_ids: Optional[Dict[str, str]],
5454
pandas_kwargs: Dict[str, Any],
55-
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
55+
) -> pd.DataFrame:
5656
parser_func = _resolve_format(read_format)
5757
executor: _BaseExecutor = _get_executor(use_threads=use_threads)
5858
tables = executor.map(

awswrangler/s3/_write_parquet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _to_parquet_chunked(
171171

172172

173173
@engine.dispatch_on_engine
174-
def _to_parquet(
174+
def _to_parquet( # pylint: disable=unused-argument
175175
df: pd.DataFrame,
176176
schema: pa.Schema,
177177
index: bool,
@@ -187,6 +187,7 @@ def _to_parquet(
187187
path_root: Optional[str] = None,
188188
filename_prefix: Optional[str] = None,
189189
max_rows_by_file: Optional[int] = 0,
190+
bucketing: bool = False,
190191
) -> List[str]:
191192
s3_client = s3_client if s3_client else _utils.client(service_name="s3")
192193
file_path = _get_file_path(

0 commit comments

Comments (0)