Skip to content

Commit 61c0c51

Browse files
authored
[Ray] Support basic subtask retry and lineage reconstruction (#2969)
1 parent aa23fd0 commit 61c0c51

File tree

11 files changed

+238
-46
lines changed

11 files changed

+238
-46
lines changed

mars/conftest.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,8 @@ def _ray_large_cluster(request): # pragma: no cover
9696
param = getattr(request, "param", {})
9797
num_nodes = param.get("num_nodes", 3)
9898
num_cpus = param.get("num_cpus", 16)
99-
try:
100-
from ray.cluster_utils import Cluster
101-
except ModuleNotFoundError:
102-
from ray._private.cluster_utils import Cluster
99+
from ray.cluster_utils import Cluster
100+
103101
cluster = Cluster()
104102
remote_nodes = []
105103
for i in range(num_nodes):
@@ -114,11 +112,14 @@ def _ray_large_cluster(request): # pragma: no cover
114112
except TypeError:
115113
job_config = None
116114
ray.init(address=cluster.address, job_config=job_config)
117-
register_ray_serializers()
115+
use_ray_serialization = param.get("use_ray_serialization", True)
116+
if use_ray_serialization:
117+
register_ray_serializers()
118118
try:
119-
yield
119+
yield cluster
120120
finally:
121-
unregister_ray_serializers()
121+
if use_ray_serialization:
122+
unregister_ray_serializers()
122123
Router.set_instance(None)
123124
RayServer.clear()
124125
ray.shutdown()
@@ -158,6 +159,14 @@ async def ray_create_mars_cluster(request):
158159
yield client
159160

160161

162+
@pytest.fixture
163+
def stop_mars():
164+
yield
165+
import mars
166+
167+
mars.stop_server()
168+
169+
161170
@pytest.fixture(scope="module")
162171
def _new_test_session():
163172
from .deploy.oscar.tests.session import new_test_session

mars/dataframe/contrib/raydataset/dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import operator
15+
from functools import reduce
1416

1517
from ....utils import lazy_import
1618
from .mldataset import _rechunk_if_needed
@@ -59,5 +61,6 @@ def __getstate__():
5961

6062

6163
def get_chunk_refs(df):
62-
fetched_infos: Dict[str, List] = df.fetch_infos(fields=["object_id"])
63-
return fetched_infos["object_id"]
64+
fetched_infos: Dict[str, List] = df.fetch_infos(["object_refs"])
65+
object_refs = reduce(operator.concat, fetched_infos["object_refs"])
66+
return object_refs

mars/dataframe/contrib/raydataset/mldataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ def to_ray_mldataset(df, num_shards: int = None):
104104
# chunk1 for addr1,
105105
# chunk2 & chunk3 for addr2,
106106
# chunk4 for addr1
107-
fetched_infos: Dict[str, List] = df.fetch_infos(fields=["band", "object_id"])
107+
fetched_infos: Dict[str, List] = df.fetch_infos(fields=["bands", "object_refs"])
108108
chunk_addr_refs: List[Tuple[Tuple, "ray.ObjectRef"]] = [
109-
(band, object_id)
110-
for band, object_id in zip(fetched_infos["band"], fetched_infos["object_id"])
109+
(bands[0], object_refs[0])
110+
for bands, object_refs in zip(
111+
fetched_infos["bands"], fetched_infos["object_refs"]
112+
)
111113
]
112114
group_to_obj_refs: Dict[str, List[ray.ObjectRef]] = _group_chunk_refs(
113115
chunk_addr_refs, num_shards

mars/dataframe/contrib/raydataset/tests/test_raydataset.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ async def test_convert_to_ray_dataset(
6060
with session:
6161
value = np.random.rand(10, 10)
6262
chunk_size, num_shards = test_option
63-
df: md.DataFrame = md.DataFrame(value, chunk_size=chunk_size)
63+
# ray dataset needs str columns
64+
df: md.DataFrame = md.DataFrame(
65+
value,
66+
chunk_size=chunk_size,
67+
columns=[f"c{i}" for i in range(value.shape[1])],
68+
)
6469
df.execute()
6570

6671
ds = mdd.to_ray_dataset(df, num_shards=num_shards)
@@ -162,7 +167,8 @@ async def test_mars_with_xgboost_sklearn_reg(ray_start_regular_shared, create_cl
162167
session = new_session(address=create_cluster.address, default=True)
163168
with session:
164169
np_X, np_y = make_regression(n_samples=1_0000, n_features=10)
165-
X, y = md.DataFrame(np_X), md.DataFrame({"target": np_y})
170+
columns = [f"c{i}" for i in range(np_X.shape[1])]
171+
X, y = md.DataFrame(np_X, columns=columns), md.DataFrame({"target": np_y})
166172
df: md.DataFrame = md.concat([md.DataFrame(X), md.DataFrame(y)], axis=1)
167173
df.execute()
168174

@@ -172,10 +178,10 @@ async def test_mars_with_xgboost_sklearn_reg(ray_start_regular_shared, create_cl
172178

173179
import gc
174180

175-
gc.collect() # Ensure MLDataset does hold mars dataframe to avoid gc.
181+
gc.collect() # Ensure the Dataset holds a reference to the mars dataframe so that it is not garbage-collected.
176182
ray_params = RayParams(num_actors=2, cpus_per_actor=1)
177183
reg = RayXGBRegressor(ray_params=ray_params, random_state=42)
178184
# train
179185
reg.fit(RayDMatrix(ds, "target"), y=None, ray_params=ray_params)
180186
reg.predict(RayDMatrix(ds, "target"))
181-
reg.predict(pd.DataFrame(np_X))
187+
reg.predict(pd.DataFrame(np_X, columns=columns))

mars/deploy/oscar/session.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,14 @@ async def fetch(self, *tileables, **kwargs) -> list:
11611161
return result
11621162

11631163
async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
1164-
available_fields = {"object_id", "level", "memory_size", "store_size", "band"}
1164+
available_fields = {
1165+
"object_id",
1166+
"object_refs",
1167+
"level",
1168+
"memory_size",
1169+
"store_size",
1170+
"bands",
1171+
}
11651172
if fields is None:
11661173
fields = available_fields
11671174
else:
@@ -1175,34 +1182,22 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
11751182
if kwargs: # pragma: no cover
11761183
unexpected_keys = ", ".join(list(kwargs.keys()))
11771184
raise TypeError(f"`fetch` got unexpected arguments: {unexpected_keys}")
1178-
1185+
# The following fields need to access the storage API to get the meta.
1186+
_need_query_storage_fields = {"level", "memory_size", "store_size"}
1187+
_need_query_storage = bool(_need_query_storage_fields & fields)
11791188
with enter_mode(build=True):
1180-
chunks = []
1181-
get_chunk_metas = []
1182-
fetch_infos_list = []
1183-
for tileable in tileables:
1184-
fetch_tileable, _ = self._get_to_fetch_tileable(tileable)
1185-
fetch_infos = []
1186-
for chunk in fetch_tileable.chunks:
1187-
chunks.append(chunk)
1188-
get_chunk_metas.append(
1189-
self._meta_api.get_chunk_meta.delay(chunk.key, fields=["bands"])
1190-
)
1191-
fetch_infos.append(
1192-
ChunkFetchInfo(tileable=tileable, chunk=chunk, indexes=None)
1193-
)
1194-
fetch_infos_list.append(fetch_infos)
1195-
chunk_metas = await self._meta_api.get_chunk_meta.batch(*get_chunk_metas)
1196-
chunk_to_band = {
1197-
chunk: meta["bands"][0] for chunk, meta in zip(chunks, chunk_metas)
1198-
}
1199-
1189+
chunk_to_bands, fetch_infos_list, result = await self._query_meta_service(
1190+
tileables, fields, _need_query_storage
1191+
)
1192+
if not _need_query_storage:
1193+
assert result is not None
1194+
return result
12001195
storage_api_to_gets = defaultdict(list)
12011196
storage_api_to_fetch_infos = defaultdict(list)
12021197
for fetch_info in itertools.chain(*fetch_infos_list):
12031198
chunk = fetch_info.chunk
1204-
band = chunk_to_band[chunk]
1205-
storage_api = await self._get_storage_api(band)
1199+
bands = chunk_to_bands[chunk]
1200+
storage_api = await self._get_storage_api(bands[0])
12061201
storage_api_to_gets[storage_api].append(
12071202
storage_api.get_infos.delay(chunk.key)
12081203
)
@@ -1219,7 +1214,7 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
12191214
for fetch_infos in fetch_infos_list:
12201215
fetched = defaultdict(list)
12211216
for fetch_info in fetch_infos:
1222-
band = chunk_to_band[fetch_info.chunk]
1217+
bands = chunk_to_bands[fetch_info.chunk]
12231218
# Currently there's only one item in the returned List from storage_api.get_infos()
12241219
data = fetch_info.data[0]
12251220
if "object_id" in fields:
@@ -1232,12 +1227,47 @@ async def fetch_infos(self, *tileables, fields, **kwargs) -> list:
12321227
fetched["store_size"].append(data.store_size)
12331228
# data.band lacks the IP info, e.g. 'numa-0'
12341229
# while each band in bands includes it, e.g. (address0, 'numa-0')
1235-
if "band" in fields:
1236-
fetched["band"].append(band)
1230+
if "bands" in fields:
1231+
fetched["bands"].append(bands)
12371232
result.append(fetched)
12381233

12391234
return result
12401235

1236+
async def _query_meta_service(self, tileables, fields, query_storage):
1237+
chunks = []
1238+
get_chunk_metas = []
1239+
fetch_infos_list = []
1240+
for tileable in tileables:
1241+
fetch_tileable, _ = self._get_to_fetch_tileable(tileable)
1242+
fetch_infos = []
1243+
for chunk in fetch_tileable.chunks:
1244+
chunks.append(chunk)
1245+
get_chunk_metas.append(
1246+
self._meta_api.get_chunk_meta.delay(
1247+
chunk.key,
1248+
fields=["bands"] if query_storage else fields,
1249+
)
1250+
)
1251+
fetch_infos.append(
1252+
ChunkFetchInfo(tileable=tileable, chunk=chunk, indexes=None)
1253+
)
1254+
fetch_infos_list.append(fetch_infos)
1255+
chunk_metas = await self._meta_api.get_chunk_meta.batch(*get_chunk_metas)
1256+
if not query_storage:
1257+
result = []
1258+
chunk_to_meta = dict(zip(chunks, chunk_metas))
1259+
for fetch_infos in fetch_infos_list:
1260+
fetched = defaultdict(list)
1261+
for fetch_info in fetch_infos:
1262+
for field in fields:
1263+
fetched[field].append(chunk_to_meta[fetch_info.chunk][field])
1264+
result.append(fetched)
1265+
return {}, fetch_infos_list, result
1266+
chunk_to_bands = {
1267+
chunk: meta["bands"] for chunk, meta in zip(chunks, chunk_metas)
1268+
}
1269+
return chunk_to_bands, fetch_infos_list, None
1270+
12411271
async def decref(self, *tileable_keys):
12421272
logger.debug("Decref tileables on client: %s", tileable_keys)
12431273
return await self._lifecycle_api.decref_tileables(list(tileable_keys))

mars/deploy/oscar/tests/test_local.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,12 @@ async def test_fetch_infos(create_cluster):
381381
assert "level" in fetched_infos
382382
assert "memory_size" in fetched_infos
383383
assert "store_size" in fetched_infos
384-
assert "band" in fetched_infos
384+
assert "bands" in fetched_infos
385+
386+
fetched_infos = df.fetch_infos(fields=["object_id", "bands"])
387+
assert "object_id" in fetched_infos
388+
assert "bands" in fetched_infos
389+
assert len(fetched_infos) == 2
385390

386391
fetch_infos((df, df), fields=None)
387392
results_infos = mr.ExecutableTuple([df, df]).execute()._fetch_infos()
@@ -390,7 +395,7 @@ async def test_fetch_infos(create_cluster):
390395
assert "level" in results_infos[0]
391396
assert "memory_size" in results_infos[0]
392397
assert "store_size" in results_infos[0]
393-
assert "band" in results_infos[0]
398+
assert "bands" in results_infos[0]
394399

395400

396401
async def _run_web_session_test(web_address):

mars/deploy/oscar/tests/test_ray.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
import asyncio
1616
import copy
17+
import operator
1718
import os
1819
import subprocess
1920
import sys
2021
import tempfile
2122
import threading
2223
import time
24+
from functools import reduce
2325

2426
import numpy as np
2527
import pandas as pd
@@ -154,6 +156,12 @@ async def test_execute_describe(ray_start_regular, create_cluster):
154156
@pytest.mark.asyncio
155157
async def test_fetch_infos(ray_start_regular, create_cluster):
156158
await test_local.test_fetch_infos(create_cluster)
159+
df = md.DataFrame(mt.random.RandomState(0).rand(5000, 1, chunk_size=1000))
160+
df.execute()
161+
fetched_infos = df.fetch_infos(fields=["object_refs"])
162+
object_refs = reduce(operator.concat, fetched_infos["object_refs"])
163+
assert len(fetched_infos) == 1
164+
assert len(object_refs) == 5
157165

158166

159167
@require_ray
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import operator
17+
from functools import reduce
18+
19+
import pandas as pd
20+
import pytest
21+
22+
import mars
23+
from .... import dataframe as md
24+
from .... import tensor as mt
25+
from ....tests.core import require_ray
26+
from ....utils import lazy_import
27+
28+
ray = lazy_import("ray")
29+
30+
31+
@require_ray
32+
@pytest.mark.parametrize(
33+
"ray_large_cluster",
34+
[{"num_nodes": 0, "use_ray_serialization": False}],
35+
indirect=True,
36+
)
37+
@pytest.mark.parametrize("reconstruction_enabled", [True, False])
38+
def test_basic_object_reconstruction(
39+
ray_large_cluster, reconstruction_enabled, stop_mars
40+
):
41+
config = {
42+
"num_heartbeats_timeout": 10,
43+
"raylet_heartbeat_period_milliseconds": 200,
44+
"object_timeout_milliseconds": 200,
45+
}
46+
# Workaround to reset the config to the default value.
47+
if not reconstruction_enabled:
48+
config["lineage_pinning_enabled"] = False
49+
subtask_max_retries = 0
50+
else:
51+
subtask_max_retries = 1
52+
53+
cluster = ray_large_cluster
54+
# Head node with no resources.
55+
cluster.add_node(
56+
num_cpus=0,
57+
_system_config=config,
58+
enable_object_reconstruction=reconstruction_enabled,
59+
)
60+
ray.init(address=cluster.address)
61+
# Node to place the initial object.
62+
node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8)
63+
mars.new_session(
64+
backend="ray",
65+
config={"scheduling.subtask_max_retries": subtask_max_retries},
66+
default=True,
67+
)
68+
cluster.wait_for_nodes()
69+
70+
df = md.DataFrame(mt.random.RandomState(0).rand(2_000_000, 1, chunk_size=1_000_000))
71+
df.execute()
72+
# this will submit new ray tasks
73+
df2 = df.map_chunk(lambda pdf: pdf * 2).execute()
74+
executed_infos = df2.fetch_infos(fields=["object_refs"])
75+
object_refs = reduce(operator.concat, executed_infos["object_refs"])
76+
head5 = df2.head(5).to_pandas()
77+
78+
cluster.remove_node(node_to_kill, allow_graceful=False)
79+
node_to_kill = cluster.add_node(num_cpus=1, object_store_memory=10**8)
80+
81+
# Use a dependent_task to avoid fetching the lost objects locally.
82+
@ray.remote
83+
def dependent_task(x):
84+
return x
85+
86+
if reconstruction_enabled:
87+
ray.get([dependent_task.remote(ref) for ref in object_refs])
88+
new_head5 = df2.head(5).to_pandas()
89+
pd.testing.assert_frame_equal(head5, new_head5)
90+
else:
91+
with pytest.raises(ray.exceptions.RayTaskError):
92+
df2.head(5).to_pandas()
93+
with pytest.raises(ray.exceptions.ObjectLostError):
94+
ray.get(object_refs)
95+
96+
# Losing the object a second time will cause reconstruction to fail because
97+
# we have reached the max task retries.
98+
cluster.remove_node(node_to_kill, allow_graceful=False)
99+
cluster.add_node(num_cpus=1, object_store_memory=10**8)
100+
101+
if reconstruction_enabled:
102+
with pytest.raises(
103+
ray.exceptions.ObjectReconstructionFailedMaxAttemptsExceededError
104+
):
105+
ray.get(object_refs)
106+
else:
107+
with pytest.raises(ray.exceptions.ObjectLostError):
108+
ray.get(object_refs)

0 commit comments

Comments
 (0)