Skip to content

Commit 4f4a73c

Browse files
authored
(feat) validate partitions along row axis, add warning (#1700)
* Check df is partitioned along row axis, add warning * Fix - pass axis * PR feedback
1 parent f38a00e commit 4f4a73c

File tree

1 file changed

+35
-4
lines changed
  • awswrangler/distributed/ray/modin

1 file changed

+35
-4
lines changed

awswrangler/distributed/ray/modin/_core.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,34 @@
11
"""Modin on Ray Core module (PRIVATE)."""
22
# pylint: disable=import-outside-toplevel
3+
import logging
34
from functools import wraps
45
from typing import Any, Callable, Optional
56

7+
import numpy as np
68
import pandas as pd
79
from modin.distributed.dataframe.pandas import from_partitions, unwrap_partitions
810
from modin.pandas import DataFrame as ModinDataFrame
911

12+
_logger: logging.Logger = logging.getLogger(__name__)
13+
14+
15+
def _validate_partition_shape(df: pd.DataFrame) -> bool:
    """
    Check whether a Modin data frame is partitioned along the row axis only.

    The partitions are unwrapped exactly as they are currently stored
    (no ``axis`` argument is passed) and arranged into a NumPy array.
    The frame counts as row-partitioned when the second dimension of
    that array has length one.

    Parameters
    ----------
    df : pd.DataFrame
        Modin data frame

    Returns
    -------
    bool
        ``True`` if the second dimension of the partition grid is 1,
        ``False`` otherwise.
    """
    # Keep the stored layout: calling unwrap_partitions without an axis
    # returns the partitions as they currently are.
    partition_grid = np.array(unwrap_partitions(df))
    blocks_along_second_axis = partition_grid.shape[1]
    return blocks_along_second_axis == 1
31+
1032

1133
def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]:
1234
"""
@@ -31,14 +53,23 @@ def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]:
3153
def wrapper(
    df: pd.DataFrame,
    *args: Any,
    axis: Optional[int] = None,
    row_lengths: Optional[int] = None,
    validate_partitions: bool = True,
    **kwargs: Any,
) -> Any:
    """
    Repartition a Modin data frame if needed, then call the wrapped function.

    Parameters
    ----------
    df : pd.DataFrame
        Data frame handed to the wrapped function. Only instances of
        ``ModinDataFrame`` are validated or repartitioned.
    *args : Any
        Positional arguments forwarded to the wrapped function.
    axis : Optional[int]
        Axis passed to ``unwrap_partitions`` / ``from_partitions``.
        ``None`` means no repartitioning unless validation forces it to 0.
    row_lengths : Optional[int]
        Forwarded to ``from_partitions`` as ``row_lengths``.
    validate_partitions : bool
        When ``True``, check the partition layout first and force
        ``axis=0`` (with a warning) if the check fails.
    **kwargs : Any
        Keyword arguments forwarded to the wrapped function.

    Returns
    -------
    Any
        Whatever the wrapped function returns.
    """
    # Anything that is not a Modin data frame is passed through untouched.
    if not isinstance(df, ModinDataFrame):
        return function(df, *args, **kwargs)

    # Validate partitions and force the row axis (axis=0) when needed,
    # to avoid a situation where columns are split along multiple blocks.
    split_along_columns = validate_partitions and not _validate_partition_shape(df)
    if split_along_columns:
        _logger.warning(
            "Partitions of this data frame are detected to be split along column axis. "
            "The dataframe will be automatically repartitioned along row axis to ensure "
            "each partition can be processed independently."
        )
        axis = 0

    # No axis requested and none forced: keep the current partitioning.
    if axis is None:
        return function(df, *args, **kwargs)

    unwrapped = unwrap_partitions(df, axis=axis)
    df = from_partitions(unwrapped, axis=axis, row_lengths=row_lengths)
    return function(df, *args, **kwargs)
4374

4475
return wrapper

0 commit comments

Comments
 (0)