11"""Modin on Ray Core module (PRIVATE)."""
22# pylint: disable=import-outside-toplevel
3+ import logging
34from functools import wraps
45from typing import Any , Callable , Optional
56
7+ import numpy as np
68import pandas as pd
79from modin .distributed .dataframe .pandas import from_partitions , unwrap_partitions
810from modin .pandas import DataFrame as ModinDataFrame
911
12+ _logger : logging .Logger = logging .getLogger (__name__ )
13+
14+
15+ def _validate_partition_shape (df : pd .DataFrame ) -> bool :
16+ """
17+ Validate if partitions of the data frame are partitioned along row axis.
18+
19+ Parameters
20+ ----------
21+ df : pd.DataFrame
22+ Modin data frame
23+
24+ Returns
25+ -------
26+ bool
27+ """
28+ # Unwrap partitions as they are currently stored (axis=None)
29+ partitions_shape = np .array (unwrap_partitions (df )).shape
30+ return partitions_shape [1 ] == 1
31+
1032
1133def modin_repartition (function : Callable [..., Any ]) -> Callable [..., Any ]:
1234 """
@@ -31,14 +53,23 @@ def modin_repartition(function: Callable[..., Any]) -> Callable[..., Any]:
3153 def wrapper (
3254 df : pd .DataFrame ,
3355 * args : Any ,
34- axis : int = 0 ,
56+ axis : Optional [ int ] = None ,
3557 row_lengths : Optional [int ] = None ,
58+ validate_partitions : bool = True ,
3659 ** kwargs : Any ,
3760 ) -> Any :
38- # Repartition Modin data frame along row (axis=0) axis
61+ # Validate partitions and repartition Modin data frame along row (axis=0) axis
3962 # to avoid a situation where columns are split along multiple blocks
40- if isinstance (df , ModinDataFrame ) and axis is not None :
41- df = from_partitions (unwrap_partitions (df , axis = axis ), axis = axis , row_lengths = row_lengths )
63+ if isinstance (df , ModinDataFrame ):
64+ if validate_partitions and not _validate_partition_shape (df ):
65+ _logger .warning (
66+ "Partitions of this data frame are detected to be split along column axis. "
67+ "The dataframe will be automatically repartitioned along row axis to ensure "
68+ "each partition can be processed independently."
69+ )
70+ axis = 0
71+ if axis is not None :
72+ df = from_partitions (unwrap_partitions (df , axis = axis ), axis = axis , row_lengths = row_lengths )
4273 return function (df , * args , ** kwargs )
4374
4475 return wrapper
0 commit comments