
Commit b22b93d

jeffreykennethli and devin-petersohn authored
FIX-#4464: Refactor Ray utils and quick fix groupby.count failing on virtual partitions (#4490)
Co-authored-by: Devin Petersohn <[email protected]>
Signed-off-by: jeffreykennethli <[email protected]>
1 parent dcee13d commit b22b93d

7 files changed: +148 additions, −60 deletions


docs/release_notes/release_notes-0.15.0.rst

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ Key Features and Updates
   * FIX-#4481: Allow clipping with a Modin Series of bounds (#4486)
   * FIX-#4504: Support na_action in applymap (#4505)
   * FIX-#4503: Stop the memory logging thread after session exit (#4515)
+  * FIX-#4464: Refactor Ray utils and quick fix groupby.count failing on virtual partitions (#4490)
 * Performance enhancements
   * FEAT-#4320: Add connectorx as an alternative engine for read_sql (#4346)
   * PERF-#4493: Use partition size caches more in Modin dataframe (#4495)

modin/core/dataframe/pandas/partitioning/axis_partition.py

Lines changed: 24 additions & 14 deletions
@@ -83,7 +83,6 @@ def apply(
                     num_splits,
                     len(self.list_of_blocks),
                     other_shape,
-                    kwargs,
                     *tuple(
                         self.list_of_blocks
                         + [
@@ -92,11 +91,12 @@ def apply(
                             for part in axis_partition.list_of_blocks
                         ]
                     ),
+                    **kwargs,
                 )
             )
-        args = [self.axis, func, num_splits, kwargs, maintain_partitioning]
+        args = [self.axis, func, num_splits, maintain_partitioning]
         args.extend(self.list_of_blocks)
-        return self._wrap_partitions(self.deploy_axis_func(*args))
+        return self._wrap_partitions(self.deploy_axis_func(*args, **kwargs))

     def shuffle(self, func, lengths, **kwargs):
         """
@@ -120,13 +120,13 @@ def shuffle(self, func, lengths, **kwargs):
         # We add these to kwargs and will pop them off before performing the operation.
         kwargs["manual_partition"] = True
         kwargs["_lengths"] = lengths
-        args = [self.axis, func, num_splits, kwargs, False]
+        args = [self.axis, func, num_splits, False]
         args.extend(self.list_of_blocks)
-        return self._wrap_partitions(self.deploy_axis_func(*args))
+        return self._wrap_partitions(self.deploy_axis_func(*args, **kwargs))

     @classmethod
     def deploy_axis_func(
-        cls, axis, func, num_splits, kwargs, maintain_partitioning, *partitions
+        cls, axis, func, num_splits, maintain_partitioning, *partitions, **kwargs
     ):
         """
         Deploy a function along a full axis.
@@ -139,13 +139,13 @@ def deploy_axis_func(
             The function to perform.
         num_splits : int
             The number of splits to return (see `split_result_of_axis_func_pandas`).
-        kwargs : dict
-            Additional keywords arguments to be passed in `func`.
         maintain_partitioning : bool
             If True, keep the old partitioning if possible.
             If False, create a new partition layout.
         *partitions : iterable
             All partitions that make up the full axis (row or column).
+        **kwargs : dict
+            Additional keywords arguments to be passed in `func`.

         Returns
         -------
@@ -157,7 +157,9 @@ def deploy_axis_func(
         lengths = kwargs.pop("_lengths", None)

         dataframe = pandas.concat(list(partitions), axis=axis, copy=False)
-        result = func(dataframe, **kwargs)
+        # To not mix the args for deploy_axis_func and args for func, we fold
+        # args into kwargs. This is a bit of a hack, but it works.
+        result = func(dataframe, *kwargs.pop("args", ()), **kwargs)

         if manual_partition:
             # The split function is expecting a list
@@ -180,7 +182,14 @@ def deploy_axis_func(

     @classmethod
     def deploy_func_between_two_axis_partitions(
-        cls, axis, func, num_splits, len_of_left, other_shape, kwargs, *partitions
+        cls,
+        axis,
+        func,
+        num_splits,
+        len_of_left,
+        other_shape,
+        *partitions,
+        **kwargs,
     ):
         """
         Deploy a function along a full axis between two data sets.
@@ -198,10 +207,10 @@ def deploy_func_between_two_axis_partitions(
         other_shape : np.ndarray
             The shape of right frame in terms of partitions, i.e.
             (other_shape[i-1], other_shape[i]) will indicate slice to restore i-1 axis partition.
-        kwargs : dict
-            Additional keywords arguments to be passed in `func`.
         *partitions : iterable
             All partitions that make up the full axis (row or column) for both data sets.
+        **kwargs : dict
+            Additional keywords arguments to be passed in `func`.

         Returns
         -------
@@ -222,6 +231,7 @@ def deploy_func_between_two_axis_partitions(
             for i in range(1, len(other_shape))
         ]
         rt_frame = pandas.concat(combined_axis, axis=axis ^ 1, copy=False)
-
-        result = func(lt_frame, rt_frame, **kwargs)
+        # To not mix the args for deploy_func_between_two_axis_partitions and args
+        # for func, we fold args into kwargs. This is a bit of a hack, but it works.
+        result = func(lt_frame, rt_frame, *kwargs.pop("args", ()), **kwargs)
         return split_result_of_axis_func_pandas(axis, num_splits, result)
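
The signature changes above thread `func`'s keyword arguments through `deploy_axis_func` and `deploy_func_between_two_axis_partitions` as real `**kwargs`, and fold any extra positional arguments for `func` under the reserved "args" key so they cannot collide with the deploy functions' own parameters. A minimal, self-contained sketch of that folding pattern (illustrative names, not the actual Modin classes):

import pandas

def deploy_axis_func_sketch(axis, func, *partitions, **kwargs):
    # Concatenate the block partitions into one frame along the given axis.
    frame = pandas.concat(list(partitions), axis=axis, copy=False)
    # Positional arguments meant for `func` travel under kwargs["args"], so
    # they are never confused with this function's own positional parameters.
    return func(frame, *kwargs.pop("args", ()), **kwargs)

p1 = pandas.DataFrame({"a": [1, None], "b": [3, 4]})
p2 = pandas.DataFrame({"a": [5, 6], "b": [None, 8]})
# Keyword arguments and folded positional arguments both reach `func`.
print(deploy_axis_func_sketch(0, lambda df: df.count(), p1, p2))
print(deploy_axis_func_sketch(0, lambda df, col: df[col].sum(), p1, p2, args=("b",)))

In the real methods the same kwargs dictionary also carries operation flags such as `manual_partition` and `_lengths` (set in `shuffle`), which are popped off before `func` is called.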

modin/core/execution/dask/implementations/pandas_on_dask/partitioning/virtual_partition.py

Lines changed: 18 additions & 10 deletions
@@ -55,7 +55,13 @@ def __init__(self, list_of_blocks, get_ip=False, full_axis=True):

     @classmethod
     def deploy_axis_func(
-        cls, axis, func, num_splits, kwargs, maintain_partitioning, *partitions
+        cls,
+        axis,
+        func,
+        num_splits,
+        maintain_partitioning,
+        *partitions,
+        **kwargs,
     ):
         """
         Deploy a function along a full axis.
@@ -68,13 +74,13 @@ def deploy_axis_func(
             The function to perform.
         num_splits : int
             The number of splits to return (see `split_result_of_axis_func_pandas`).
-        kwargs : dict
-            Additional keywords arguments to be passed in `func`.
         maintain_partitioning : bool
             If True, keep the old partitioning if possible.
             If False, create a new partition layout.
         *partitions : iterable
             All partitions that make up the full axis (row or column).
+        **kwargs : dict
+            Additional keywords arguments to be passed in `func`.

         Returns
         -------
@@ -89,16 +95,16 @@ def deploy_axis_func(
             axis,
             func,
             num_splits,
-            kwargs,
             maintain_partitioning,
             *partitions,
             num_returns=result_num_splits * 4,
             pure=False,
+            **kwargs,
         )

     @classmethod
     def deploy_func_between_two_axis_partitions(
-        cls, axis, func, num_splits, len_of_left, other_shape, kwargs, *partitions
+        cls, axis, func, num_splits, len_of_left, other_shape, *partitions, **kwargs
     ):
         """
         Deploy a function along a full axis between two data sets.
@@ -116,10 +122,10 @@ def deploy_func_between_two_axis_partitions(
         other_shape : np.ndarray
             The shape of right frame in terms of partitions, i.e.
             (other_shape[i-1], other_shape[i]) will indicate slice to restore i-1 axis partition.
-        kwargs : dict
-            Additional keywords arguments to be passed in `func`.
         *partitions : iterable
             All partitions that make up the full axis (row or column) for both data sets.
+        **kwargs : dict
+            Additional keywords arguments to be passed in `func`.

         Returns
         -------
@@ -134,10 +140,10 @@ def deploy_func_between_two_axis_partitions(
             num_splits,
             len_of_left,
             other_shape,
-            kwargs,
             *partitions,
             num_returns=num_splits * 4,
             pure=False,
+            **kwargs,
         )

     def _wrap_partitions(self, partitions):
@@ -200,7 +206,7 @@ class PandasOnDaskDataframeRowPartition(PandasOnDaskDataframeAxisPartition):
     axis = 1


-def deploy_dask_func(func, *args):
+def deploy_dask_func(func, *args, **kwargs):
     """
     Execute a function on an axis partition in a worker process.

@@ -210,13 +216,15 @@ def deploy_dask_func(func, *args):
         Function to be executed on an axis partition.
     *args : iterable
         Additional arguments that need to passed in ``func``.
+    **kwargs : dict
+        Additional keyword arguments to be passed in `func`.

     Returns
     -------
     list
         The result of the function ``func`` and metadata for it.
     """
-    result = func(*args)
+    result = func(*args, **kwargs)
     ip = get_ip()
     if isinstance(result, pandas.DataFrame):
         return result, len(result), len(result.columns), ip
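
The Dask implementation follows the same convention: `func`'s keyword arguments ride along as keyword arguments of the deploy call (next to Dask-specific ones such as `num_returns` and `pure`), and `deploy_dask_func` finally forwards them into `func` inside the worker. A hedged, pure-Python sketch of that last step, with no Dask cluster involved and illustrative data:

import pandas

def deploy_dask_func_sketch(func, *args, **kwargs):
    # Mirrors the new signature: keyword arguments now reach `func` instead of
    # being dropped on the way to the worker.
    result = func(*args, **kwargs)
    if isinstance(result, pandas.DataFrame):
        return result, len(result), len(result.columns)
    return result

df = pandas.DataFrame({"a": [1.0, None, 3.0]})
# A keyword argument such as skipna is forwarded to the applied function.
print(deploy_dask_func_sketch(pandas.DataFrame.sum, df, skipna=False))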

modin/core/execution/ray/common/utils.py

Lines changed: 35 additions & 0 deletions
@@ -16,6 +16,7 @@
 import os
 import sys
 import psutil
+from packaging import version
 import warnings

 import ray
@@ -32,6 +33,12 @@
     ValueSource,
 )

+ObjectIDType = ray.ObjectRef
+if version.parse(ray.__version__) >= version.parse("1.2.0"):
+    from ray.util.client.common import ClientObjectRef
+
+    ObjectIDType = (ray.ObjectRef, ClientObjectRef)
+

 def _move_stdlib_ahead_of_site_packages(*args):
     """
@@ -223,3 +230,31 @@ def initialize_ray(
             NPartitions._put(num_gpus)
         else:
             NPartitions._put(num_cpus)
+
+
+def deserialize(obj):
+    """
+    Deserialize a Ray object.
+
+    Parameters
+    ----------
+    obj : ObjectIDType, iterable of ObjectIDType, or mapping of keys to ObjectIDTypes
+        Object(s) to deserialize.
+
+    Returns
+    -------
+    obj
+        The deserialized object.
+    """
+    if isinstance(obj, ObjectIDType):
+        return ray.get(obj)
+    elif isinstance(obj, (tuple, list)) and any(
+        isinstance(o, ObjectIDType) for o in obj
+    ):
+        return ray.get(list(obj))
+    elif isinstance(obj, dict) and any(
+        isinstance(val, ObjectIDType) for val in obj.values()
+    ):
+        return dict(zip(obj.keys(), ray.get(list(obj.values()))))
+    else:
+        return obj
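
A hedged usage sketch of the new `deserialize` helper and the `ObjectIDType` alias it relies on (assumes a local Ray session; the commented outputs are illustrative). Because the list and dict branches call `ray.get` on the whole collection, the example keeps every entry an object reference:

import ray
from modin.core.execution.ray.common.utils import deserialize, ObjectIDType

ray.init(ignore_reinit_error=True)

ref_a, ref_b = ray.put(1), ray.put(2)
assert isinstance(ref_a, ObjectIDType)  # also covers ClientObjectRef on Ray >= 1.2.0

print(deserialize(ref_a))                     # 1
print(deserialize([ref_a, ref_b]))            # [1, 2]
print(deserialize({"a": ref_a, "b": ref_b}))  # {'a': 1, 'b': 2}
print(deserialize("plain value"))             # non-Ray objects come back unchanged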

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py

Lines changed: 1 addition & 22 deletions
@@ -15,19 +15,13 @@

 import ray
 from ray.util import get_node_ip_address
-from packaging import version
 import uuid
+from modin.core.execution.ray.common.utils import deserialize, ObjectIDType

 from modin.core.dataframe.pandas.partitioning.partition import PandasDataframePartition
 from modin.pandas.indexing import compute_sliced_len
 from modin.logging import get_logger

-ObjectIDType = ray.ObjectRef
-if version.parse(ray.__version__) >= version.parse("1.2.0"):
-    from ray.util.client.common import ClientObjectRef
-
-    ObjectIDType = (ray.ObjectRef, ClientObjectRef)
-
 compute_sliced_len = ray.remote(compute_sliced_len)


@@ -419,21 +413,6 @@ def _apply_list_of_funcs(funcs, partition):  # pragma: no cover
     str
         The node IP address of the worker process.
     """
-
-    def deserialize(obj):
-        if isinstance(obj, ObjectIDType):
-            return ray.get(obj)
-        elif isinstance(obj, (tuple, list)) and any(
-            isinstance(o, ObjectIDType) for o in obj
-        ):
-            return ray.get(list(obj))
-        elif isinstance(obj, dict) and any(
-            isinstance(val, ObjectIDType) for val in obj.values()
-        ):
-            return dict(zip(obj.keys(), ray.get(list(obj.values()))))
-        else:
-            return obj
-
     for func, args, kwargs in funcs:
         func = deserialize(func)
         args = deserialize(args)
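
With the helper promoted to `modin.core.execution.ray.common.utils`, worker-side code such as `_apply_list_of_funcs` imports it instead of redefining it locally. A simplified sketch of that deserialization loop (illustrative calls, assumes a local Ray session), where any piece of a queued (func, args, kwargs) triple may arrive as an object reference:

import pandas
import ray
from modin.core.execution.ray.common.utils import deserialize

ray.init(ignore_reinit_error=True)

partition = pandas.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
queued_calls = [
    (lambda df, n: df.head(n), ray.put((2,)), {}),
    (lambda df, **kw: df.sum(**kw), (), ray.put({"skipna": True})),
]
for func, args, kwargs in queued_calls:
    # Materialize whichever pieces are Ray object references.
    func = deserialize(func)
    args = deserialize(args)
    kwargs = deserialize(kwargs)
    partition = func(partition, *args, **kwargs)
print(partition)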
