Skip to content

Commit adb4161

Browse files
authored
Add support for dask.persist (#2953)
1 parent 8891beb commit adb4161

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

mars/contrib/dask/scheduler.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
from dask.core import istask, ishashable
1616

17-
from typing import List, Tuple
17+
from typing import List, Tuple, Union
1818
from .utils import reduce
1919
from ...remote import spawn
20+
from ...deploy.oscar.session import execute
2021

2122

22-
def mars_scheduler(dsk: dict, keys: List[List[str]]):
23+
def mars_scheduler(dsk: dict, keys: Union[List[List[str]], List[str]]):
2324
"""
2425
A Dask-Mars scheduler
2526
@@ -30,22 +31,29 @@ def mars_scheduler(dsk: dict, keys: List[List[str]]):
3031
----------
3132
dsk: Dict
3233
Dask graph, represented as a task DAG dictionary.
33-
keys: List[List[str]]
34-
2d-list of Dask graph keys whose values we wish to compute and return.
34+
keys: Union[List[List[str]], List[str]]
35+
1d or 2d list of Dask graph keys whose values we wish to compute and return.
3536
3637
Returns
3738
-------
3839
Object
39-
Computed values corresponding to the provided keys.
40+
Computed values corresponding to the provided keys, with the same dimension.
4041
"""
41-
res = reduce(mars_dask_get(dsk, keys)).execute().fetch()
42-
if not isinstance(res, List):
43-
return [[res]]
44-
else:
45-
return res
4642

43+
if isinstance(keys, List) and not isinstance(keys[0], List): # 1d keys
44+
task = execute(mars_dask_get(dsk, keys))
45+
if not isinstance(task, List):
46+
task = [task]
47+
return map(lambda x: x.fetch(), task)
48+
else: # 2d keys
49+
res = execute(reduce(mars_dask_get(dsk, keys))).fetch()
50+
if not isinstance(res, List):
51+
return [[res]]
52+
else:
53+
return res
4754

48-
def mars_dask_get(dsk: dict, keys: List[List]):
55+
56+
def mars_dask_get(dsk: dict, keys: Union[List[List[str]], List[str]]):
4957
"""
5058
A Dask-Mars convert function. This function will send the dask graph layers
5159
to Mars Remote API, generating mars objects corresponding to the provided keys.
@@ -54,21 +62,21 @@ def mars_dask_get(dsk: dict, keys: List[List]):
5462
----------
5563
dsk: Dict
5664
Dask graph, represented as a task DAG dictionary.
57-
keys: List[List[str]]
58-
2d-list of Dask graph keys whose values we wish to compute and return.
65+
keys: Union[List[List[str]], List[str]]
66+
1d or 2d list of Dask graph keys whose values we wish to compute and return.
5967
6068
Returns
6169
-------
6270
Object
63-
Spawned mars objects corresponding to the provided keys.
71+
Spawned mars objects corresponding to the provided keys, with the same dimension.
6472
"""
6573

6674
def _get_arg(a):
6775
# if arg contains layer index or callable objs, handle it
6876
if ishashable(a) and a in dsk.keys():
6977
while ishashable(a) and a in dsk.keys():
7078
a = dsk[a]
71-
return _execute_task(a)
79+
return _spawn_task(a)
7280
elif not isinstance(a, str) and hasattr(a, "__getitem__"):
7381
if istask(
7482
a
@@ -80,9 +88,14 @@ def _get_arg(a):
8088
return type(a)(_get_arg(i) for i in a)
8189
return a
8290

83-
def _execute_task(task: tuple):
91+
def _spawn_task(task: tuple):
8492
if not istask(task):
8593
return _get_arg(task)
8694
return spawn(task[0], args=tuple(_get_arg(a) for a in task[1:]))
8795

88-
return [[_execute_task(dsk[k]) for k in keys_d] for keys_d in keys]
96+
return [
97+
[_spawn_task(dsk[k]) for k in keys_d]
98+
if isinstance(keys_d, List)
99+
else _spawn_task(dsk[keys_d])
100+
for keys_d in keys
101+
]

mars/contrib/dask/tests/test_dask.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,41 @@ def inc(x: int):
143143
assert dask.compute(test_obj) == dask.compute(
144144
test_obj, scheduler=mars_scheduler
145145
)
146+
147+
148+
@pytest.mark.skipif(not dask_installed, reason="dask not installed")
def test_persist(setup_cluster):
    """Compare a delayed value persisted with ``mars_scheduler`` to dask's default persist.

    The same delayed node is persisted twice, once with ``mars_scheduler`` and
    once with dask's default scheduler. A further delayed step is stacked on
    each persisted value, and both final results must be equal.
    """
    from dask import delayed

    def inc(x):
        return x + 1

    base = delayed(inc)(1)

    # Persist through Mars first, then through dask, and chain one more
    # increment on top of each persisted value.
    via_mars = delayed(inc)(base.persist(scheduler=mars_scheduler))
    via_dask = delayed(inc)(base.persist())

    expected = via_dask.compute()
    actual = via_mars.compute(scheduler=mars_scheduler)
    assert expected == actual
162+
163+
164+
@pytest.mark.skipif(not dask_installed, reason="dask not installed")
def test_partitioned_dataframe_persist(setup_cluster):
    """Check ``persist`` through ``mars_scheduler`` on a multi-partition dask DataFrame.

    A column mean is persisted once with ``mars_scheduler`` and once with
    dask's default scheduler. Each persisted mean is used to filter the same
    DataFrame, and the two computed frames must be equal.
    """
    import numpy as np
    import pandas as pd
    from dask import dataframe as dd

    # Use the public testing entry point; ``pandas._testing`` is a private
    # module and is not part of the supported pandas API.
    from pandas.testing import assert_frame_equal

    data = np.random.randn(10000, 100)
    df = dd.from_pandas(
        pd.DataFrame(data, columns=[f"col{i}" for i in range(100)]), npartitions=4
    )
    df["col0"] = df["col0"] + df["col1"] / 2
    col2_mean = df["col2"].mean()

    # The persisted scalar is fed back into a lazy filter on the same frame.
    df_mars_persist = df[df["col2"] > col2_mean.persist(scheduler=mars_scheduler)]
    df_dask_persist = df[df["col2"] > col2_mean.persist()]

    assert_frame_equal(
        df_dask_persist.compute(), df_mars_persist.compute(scheduler=mars_scheduler)
    )

0 commit comments

Comments
 (0)