
Commit 5e91b19

fix: distributed write text regression, change to singledispatch, add repartitioning utility (#1611)
* Fix distributed write text regression
* Try out singledispatch
* Minor fixes
* Refactoring
* Fix write args order
* Fix imports
* Fix import Modin df
1 parent fc612fc commit 5e91b19

8 files changed (+165 lines, -98 lines)
awswrangler/distributed/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,9 +1,10 @@
 """Distributed Module."""

-from awswrangler.distributed._distributed import initialize_ray, ray_get, ray_remote  # noqa
+from awswrangler.distributed._distributed import initialize_ray, modin_repartition, ray_get, ray_remote  # noqa

 __all__ = [
     "initialize_ray",
     "ray_get",
     "ray_remote",
+    "modin_repartition",
 ]

awswrangler/distributed/_distributed.py

Lines changed: 35 additions & 0 deletions
@@ -1,16 +1,22 @@
 """Distributed Module (PRIVATE)."""

+import logging
 import multiprocessing
 import os
 import sys
 import warnings
+from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, List, Optional

 from awswrangler._config import apply_configs, config

 if config.distributed or TYPE_CHECKING:
     import psutil
     import ray  # pylint: disable=import-error
+    from modin.distributed.dataframe.pandas import from_partitions, unwrap_partitions
+    from modin.pandas import DataFrame as ModinDataFrame
+
+_logger: logging.Logger = logging.getLogger(__name__)


 def ray_get(futures: List[Any]) -> List[Any]:
@@ -46,13 +52,42 @@ def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
     """
     if config.distributed:

+        @wraps(function)
         def wrapper(*args: Any, **kwargs: Any) -> Any:
             return ray.remote(function).remote(*args, **kwargs)

         return wrapper
     return function


+def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Decorate callable to repartition Modin data frame.
+
+    By default, repartition along row (axis=0) axis.
+    This avoids a situation where columns are split along multiple blocks.
+
+    Parameters
+    ----------
+    function : Callable[..., Any]
+        Callable as input to ray.remote
+
+    Returns
+    -------
+    Callable[..., Any]
+    """
+
+    @wraps(function)
+    def wrapper(df, *args: Any, axis=0, row_lengths=None, **kwargs: Any) -> Any:
+        if config.distributed and isinstance(df, ModinDataFrame) and axis is not None:
+            # Repartition Modin data frame along row (axis=0) axis
+            # to avoid a situation where columns are split along multiple blocks
+            df = from_partitions(unwrap_partitions(df, axis=axis), axis=axis, row_lengths=row_lengths)
+        return function(df, *args, **kwargs)
+
+    return wrapper
+
+
 @apply_configs
 def initialize_ray(
     address: Optional[str] = None,
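
A note on `modin_repartition`: the `axis` and `row_lengths` keywords are consumed by the wrapper itself rather than forwarded, so a caller can opt out of repartitioning by passing `axis=None`. A minimal usage sketch, assuming a Ray-backed environment where `config.distributed` is true; the `write_rows` function and its `path_root` parameter are hypothetical, for illustration only:

    import modin.pandas as pd
    from awswrangler.distributed import modin_repartition

    @modin_repartition
    def write_rows(df: pd.DataFrame, path_root: str) -> None:
        # By the time this body runs, df has been rebuilt from row-axis
        # (axis=0) partitions, so no column is split across blocks.
        print(len(df), path_root)

    frame = pd.DataFrame({"a": range(10), "b": range(10)})
    write_rows(frame, path_root="s3://bucket/prefix/")             # repartitions first
    write_rows(frame, path_root="s3://bucket/prefix/", axis=None)  # skips repartitioning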

awswrangler/distributed/_distributed.pyi

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ from awswrangler._config import apply_configs, config

 def ray_get(futures: List[Any]) -> List[Any]: ...
 def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]: ...
+def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]: ...
 def initialize_ray(
     address: Optional[str] = None,
     redis_password: Optional[str] = None,

awswrangler/s3/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -1,5 +1,6 @@
 """Amazon S3 Read Module."""

+from awswrangler._config import config
 from awswrangler.s3._copy import copy_objects, merge_datasets  # noqa
 from awswrangler.s3._delete import delete_objects  # noqa
 from awswrangler.s3._describe import describe_objects, get_bucket_region, size_objects  # noqa
@@ -45,3 +46,18 @@
     "download",
     "upload",
 ]
+
+if config.distributed:
+    from modin.pandas import DataFrame as ModinDataFrame
+
+    from awswrangler.s3._write_dataset import (  # pylint: disable=ungrouped-imports
+        _to_buckets,
+        _to_buckets_distributed,
+        _to_partitions,
+        _to_partitions_distributed,
+    )
+    from awswrangler.s3._write_parquet import _to_parquet, _to_parquet_distributed  # pylint: disable=ungrouped-imports
+
+    _to_parquet.register(ModinDataFrame, _to_parquet_distributed)
+    _to_buckets.register(ModinDataFrame, _to_buckets_distributed)
+    _to_partitions.register(ModinDataFrame, _to_partitions_distributed)
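
These `register` calls rely on `functools.singledispatch` picking an implementation from the runtime type of the first positional argument, which is why `df` moves to the front of each signature in this commit. A self-contained sketch of the pattern, with toy types rather than the real awswrangler internals:

    from functools import singledispatch

    @singledispatch
    def _write(df, path_root: str) -> str:
        return f"pandas write to {path_root}"  # default implementation

    class FakeModinDataFrame:  # stand-in for modin.pandas.DataFrame
        pass

    def _write_distributed(df, path_root: str) -> str:
        return f"distributed write to {path_root}"

    _write.register(FakeModinDataFrame, _write_distributed)

    print(_write(object(), "s3://x/"))              # pandas write to s3://x/
    print(_write(FakeModinDataFrame(), "s3://x/"))  # distributed write to s3://x/

Note that dispatch only happens when the frame is passed positionally (`_write(df=...)` raises a TypeError), which is also why the call sites below change from `df=df` to a positional `df`.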

awswrangler/s3/_write_concurrent.py

Lines changed: 12 additions & 6 deletions
@@ -25,27 +25,33 @@ def __init__(self, use_threads: Union[bool, int]):

     @staticmethod
     def _caller(
-        func: Callable[..., pd.DataFrame], boto3_primitives: _utils.Boto3PrimitivesType, func_kwargs: Dict[str, Any]
+        func: Callable[..., pd.DataFrame],
+        boto3_primitives: _utils.Boto3PrimitivesType,
+        *args: Any,
+        func_kwargs: Dict[str, Any],
     ) -> pd.DataFrame:
         boto3_session: boto3.Session = _utils.boto3_from_primitives(primitives=boto3_primitives)
         func_kwargs["boto3_session"] = boto3_session
         _logger.debug("Calling: %s", func)
-        return func(**func_kwargs)
+        return func(*args, **func_kwargs)

-    def write(self, func: Callable[..., List[str]], boto3_session: boto3.Session, **func_kwargs: Any) -> None:
+    def write(
+        self, func: Callable[..., List[str]], boto3_session: boto3.Session, *args: Any, **func_kwargs: Any
+    ) -> None:
         """Write File."""
         if self._exec is not None:
             _utils.block_waiting_available_thread(seq=self._futures, max_workers=self._cpus)
             _logger.debug("Submitting: %s", func)
             future = self._exec.submit(
                 _WriteProxy._caller,
-                func=func,
-                boto3_primitives=_utils.boto3_to_primitives(boto3_session=boto3_session),
+                func,
+                _utils.boto3_to_primitives(boto3_session=boto3_session),
+                *args,
                 func_kwargs=func_kwargs,
             )
             self._futures.append(future)
         else:
-            self._results += func(boto3_session=boto3_session, **func_kwargs)
+            self._results += func(*args, boto3_session=boto3_session, **func_kwargs)

     def close(self) -> List[str]:
         """Close the proxy."""

awswrangler/s3/_write_dataset.py

Lines changed: 27 additions & 35 deletions
@@ -1,6 +1,7 @@
 """Amazon S3 Write Dataset (PRIVATE)."""

 import logging
+from functools import singledispatch
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import boto3
@@ -13,19 +14,13 @@

 if config.distributed:
     import modin.pandas as pd
-    from modin.distributed.dataframe.pandas import from_partitions, unwrap_partitions
     from modin.pandas import DataFrame as ModinDataFrame
 else:
     import pandas as pd

 _logger: logging.Logger = logging.getLogger(__name__)


-def _get_subgroup_prefix(keys: Tuple[str, None], partition_cols: List[str], path_root: str) -> str:
-    subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
-    return f"{path_root}{subdir}/"
-
-
 def _get_bucketing_series(df: pd.DataFrame, bucketing_info: Tuple[List[str], int]) -> pd.Series:
     bucket_number_series = df.astype("O").apply(
         lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]),
@@ -75,6 +70,11 @@ def _get_value_hash(value: Union[str, int, bool]) -> int:
     )


+def _get_subgroup_prefix(keys: Tuple[str, None], partition_cols: List[str], path_root: str) -> str:
+    subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
+    return f"{path_root}{subdir}/"
+
+
 def _delete_objects(
     keys: Tuple[str, None],
     path_root: str,
@@ -168,7 +168,7 @@ def _write_partitions_distributed(
         )
     else:
         paths = write_func(  # type: ignore
-            df=df_group.drop(partition_cols, axis="columns"),
+            df_group.drop(partition_cols, axis="columns"),
             path_root=prefix,
             filename_prefix=filename_prefix,
             boto3_session=boto3_session,
@@ -178,10 +178,11 @@
     return prefix, df_group.name, paths


+@singledispatch
 def _to_partitions(
+    df: pd.DataFrame,
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
-    df: pd.DataFrame,
     path_root: str,
     use_threads: Union[bool, int],
     mode: str,
@@ -221,8 +222,8 @@ def _to_partitions(
         )
         if bucketing_info:
             _to_buckets(
+                subgroup,
                 func=func,
-                df=subgroup,
                 path_root=prefix,
                 bucketing_info=bucketing_info,
                 boto3_session=boto3_session,
@@ -233,11 +234,11 @@
             )
         else:
             proxy.write(
-                func=func,
-                df=subgroup,
+                func,
+                boto3_session,
+                subgroup,
                 path_root=prefix,
                 filename_prefix=filename_prefix,
-                boto3_session=boto3_session,
                 use_threads=use_threads,
                 **func_kwargs,
             )
@@ -247,9 +248,9 @@


 def _to_partitions_distributed(  # pylint: disable=unused-argument
+    df: pd.DataFrame,
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
-    df: pd.DataFrame,
     path_root: str,
     use_threads: Union[bool, int],
     mode: str,
@@ -283,7 +284,7 @@ def _to_partitions_distributed(  # pylint: disable=unused-argument
         boto3_session=None,
         **func_kwargs,
     )
-    paths: List[str] = [path for metadata in df_write_metadata.values for _, _, path in metadata]
+    paths: List[str] = [path for metadata in df_write_metadata.values for _, _, paths in metadata for path in paths]
     partitions_values: Dict[str, List[str]] = {
         prefix: list(str(p) for p in partitions) if isinstance(partitions, tuple) else [str(partitions)]
         for metadata in df_write_metadata.values
@@ -292,9 +293,10 @@
     return paths, partitions_values


+@singledispatch
 def _to_buckets(
-    func: Callable[..., List[str]],
     df: pd.DataFrame,
+    func: Callable[..., List[str]],
     path_root: str,
     bucketing_info: Tuple[List[str], int],
     filename_prefix: str,
@@ -307,11 +309,11 @@ def _to_buckets(
     df_groups = df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info))
     for bucket_number, subgroup in df_groups:
         _proxy.write(
-            func=func,
-            df=subgroup,
+            func,
+            boto3_session,
+            subgroup,
             path_root=path_root,
             filename_prefix=f"{filename_prefix}_bucket-{bucket_number:05d}",
-            boto3_session=boto3_session,
             use_threads=use_threads,
             **func_kwargs,
         )
@@ -322,8 +324,8 @@


 def _to_buckets_distributed(  # pylint: disable=unused-argument
-    func: Callable[..., List[str]],
     df: pd.DataFrame,
+    func: Callable[..., List[str]],
     path_root: str,
     bucketing_info: Tuple[List[str], int],
     filename_prefix: str,
@@ -335,7 +337,7 @@ def _to_buckets_distributed(  # pylint: disable=unused-argument
     df_groups = df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info))
     paths: List[str] = []
     df_paths = df_groups.apply(
-        func,
+        func.dispatch(ModinDataFrame),  # type: ignore
         path_root=path_root,
         filename_prefix=filename_prefix,
         boto3_session=None,
@@ -398,24 +400,14 @@
     else:
         delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session)

-    _to_partitions_fn: Callable[..., Tuple[List[str], Dict[str, List[str]]]] = _to_partitions
-    _to_buckets_fn: Callable[..., List[str]] = _to_buckets
-    if config.distributed and isinstance(df, ModinDataFrame):
-        # Ensure Modin dataframe is partitioned along row axis
-        # It avoids a situation where columns are split along multiple blocks
-        df = from_partitions(unwrap_partitions(df, axis=0), axis=0)
-
-        _to_partitions_fn = _to_partitions_distributed
-        _to_buckets_fn = _to_buckets_distributed
-
     # Writing
     partitions_values: Dict[str, List[str]] = {}
    paths: List[str]
     if partition_cols:
-        paths, partitions_values = _to_partitions_fn(
+        paths, partitions_values = _to_partitions(
+            df,
             func=func,
             concurrent_partitioning=concurrent_partitioning,
-            df=df,
             path_root=path_root,
             use_threads=use_threads,
             mode=mode,
@@ -433,9 +425,9 @@
             **func_kwargs,
         )
     elif bucketing_info:
-        paths = _to_buckets_fn(
+        paths = _to_buckets(
+            df,
             func=func,
-            df=df,
             path_root=path_root,
             use_threads=use_threads,
             bucketing_info=bucketing_info,
@@ -446,7 +438,7 @@
         )
     else:
         paths = func(
-            df=df,
+            df,
             path_root=path_root,
             filename_prefix=filename_prefix,
             use_threads=use_threads,
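
One subtlety in `_to_buckets_distributed`: `func.dispatch(ModinDataFrame)` fetches the implementation registered for the Modin type without dispatching at call time, presumably because the per-group frames handed out by `df_groups.apply` would otherwise dispatch to the default pandas implementation. A minimal sketch of `dispatch()`, again with toy types:

    from functools import singledispatch

    @singledispatch
    def write(df) -> str:
        return "default"

    class Distributed:
        pass

    write.register(Distributed, lambda df: "distributed")

    impl = write.dispatch(Distributed)  # look up the implementation, do not call it
    print(impl(object()))               # "distributed", whatever the argument type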
