Commit 6c9c4b0

[Ray] Implements get_chunks_result for Ray execution context (#3023)
1 parent ba8a6d9 commit 6c9c4b0

File tree: 7 files changed (+135 −36 lines)
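In short: this commit adds a ray-dag CI pipeline that reruns marked tests against the Ray execution backend (selected through the MARS_CI_BACKEND environment variable), threads a task_context dict mapping chunk keys to Ray ObjectRefs through the Ray task executor, and implements Context.get_chunks_result on RayExecutionContext so chunk results can be fetched with a single bulk ray.get.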

.github/workflows/platform-ci.yml

Lines changed: 12 additions & 3 deletions

@@ -18,7 +18,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        python-version: [3.8-kubernetes, 3.8-hadoop, 3.8-ray, 3.8-vineyard, 3.8-dask]
+        python-version: [3.8-kubernetes, 3.8-hadoop, 3.8-ray, 3.8-ray-dag, 3.8-vineyard, 3.8-dask]
         include:
           - { os: ubuntu-latest, python-version: 3.8-kubernetes, no-common-tests: 1,
               no-deploy: 1, with-kubernetes: "with Kubernetes" }
@@ -28,6 +28,8 @@ jobs:
               no-deploy: 1, with-vineyard: "with vineyard" }
           - { os: ubuntu-latest, python-version: 3.8-ray, no-common-tests: 1,
               no-deploy: 1, with-ray: "with ray" }
+          - { os: ubuntu-latest, python-version: 3.8-ray-dag, no-common-tests: 1,
+              no-deploy: 1, with-ray-dag: "with ray dag" }
           - { os: ubuntu-latest, python-version: 3.8-dask, no-common-tests: 1,
               no-deploy: 1, run-dask: "run dask" }
@@ -51,6 +53,7 @@ jobs:
         WITH_KUBERNETES: ${{ matrix.with-kubernetes }}
         WITH_VINEYARD: ${{ matrix.with-vineyard }}
         WITH_RAY: ${{ matrix.with-ray }}
+        WITH_RAY_DAG: ${{ matrix.with-ray-dag }}
         RUN_DASK: ${{ matrix.run-dask }}
         NO_COMMON_TESTS: ${{ matrix.no-common-tests }}
       shell: bash
@@ -67,7 +70,7 @@ jobs:
         if [[ $UNAME == "windows" ]]; then
           pip install virtualenv flaky
         else
-          pip install virtualenv flaky ray
+          pip install virtualenv flaky
           if [ -n "$WITH_KUBERNETES" ]; then
             ./.github/workflows/install-minikube.sh
             pip install kubernetes
@@ -90,7 +93,7 @@ jobs:
           sudo mv /tmp/etcd-download-test/etcdctl /usr/local/bin/
           rm -fr /tmp/etcd-$ETCD_VER-linux-amd64.tar.gz /tmp/etcd-download-test
         fi
-        if [ -n "$WITH_RAY" ]; then
+        if [ -n "$WITH_RAY" ] || [ -n "$WITH_RAY_DAG" ]; then
           pip install ray[default]==1.9.2
           pip install "xgboost_ray==0.1.5" "xgboost<1.6.0"
         fi
@@ -107,6 +110,7 @@ jobs:
         WITH_CYTHON: ${{ matrix.with-cython }}
         WITH_VINEYARD: ${{ matrix.with-vineyard }}
         WITH_RAY: ${{ matrix.with-ray }}
+        WITH_RAY_DAG: ${{ matrix.with-ray-dag }}
         RUN_DASK: ${{ matrix.run-dask }}
         NO_COMMON_TESTS: ${{ matrix.no-common-tests }}
         NUMPY_EXPERIMENTAL_ARRAY_FUNCTION: 1
@@ -143,6 +147,11 @@ jobs:
           pytest $PYTEST_CONFIG --durations=0 --timeout=600 -v -s -m ray
           coverage report
         fi
+        if [ -n "$WITH_RAY_DAG" ]; then
+          export MARS_CI_BACKEND=ray
+          pytest $PYTEST_CONFIG --durations=0 --timeout=600 -v -s -m ray_dag
+          coverage report
+        fi
         if [ -n "$RUN_DASK" ]; then
           pytest $PYTEST_CONFIG mars/contrib/dask/tests/test_dask.py
           coverage report
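The new 3.8-ray-dag matrix entry reuses the existing Ray install path (WITH_RAY or WITH_RAY_DAG both trigger the ray[default]==1.9.2 install) and runs only tests carrying the ray_dag pytest marker, with MARS_CI_BACKEND=ray exported. A minimal sketch of such a marker-gated test (hypothetical, not part of this commit):

    import pytest

    # Collected only when pytest is invoked with `-m ray_dag`,
    # as the new workflow step above does.
    @pytest.mark.ray_dag
    def test_selected_by_ray_dag_marker():
        assert 1 + 1 == 2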

mars/conftest.py

Lines changed: 13 additions & 2 deletions

@@ -29,6 +29,7 @@
 from mars.utils import lazy_import

 ray = lazy_import("ray")
+MARS_CI_BACKEND = os.environ.get("MARS_CI_BACKEND", "mars")


 @pytest.fixture(scope="module")
@@ -167,7 +168,11 @@ def _new_test_session(_stop_isolation):
     from .deploy.oscar.tests.session import new_test_session

     sess = new_test_session(
-        address="test://127.0.0.1", init_local=True, default=True, timeout=300
+        address="test://127.0.0.1",
+        backend=MARS_CI_BACKEND,
+        init_local=True,
+        default=True,
+        timeout=300,
     )
     with option_context({"show_progress": False}):
         try:
@@ -181,7 +186,12 @@ def _new_integrated_test_session(_stop_isolation):
     from .deploy.oscar.tests.session import new_test_session

     sess = new_test_session(
-        address="127.0.0.1", init_local=True, n_worker=2, default=True, timeout=300
+        address="127.0.0.1",
+        backend=MARS_CI_BACKEND,
+        init_local=True,
+        n_worker=2,
+        default=True,
+        timeout=300,
     )
     with option_context({"show_progress": False}):
         try:
@@ -213,6 +223,7 @@ def _new_gpu_test_session(_stop_isolation):  # pragma: no cover

     sess = new_test_session(
         address="127.0.0.1",
+        backend=MARS_CI_BACKEND,
         init_local=True,
         n_worker=1,
         n_cpu=1,
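Each shared test session now honors the backend chosen by CI. A standalone sketch of the selection logic (names mirror the conftest; "mars" is the native default backend):

    import os

    # The ray-dag CI job exports MARS_CI_BACKEND=ray; everywhere else the
    # variable is unset and the native "mars" backend is used.
    backend = os.environ.get("MARS_CI_BACKEND", "mars")
    print(f"test sessions will use the {backend!r} execution backend")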

mars/dataframe/base/tests/test_base_execution.py

Lines changed: 1 addition & 0 deletions

@@ -774,6 +774,7 @@ def test_isin_execution(setup):
     pd.testing.assert_frame_equal(result, expected)


+@pytest.mark.ray_dag
 def test_cut_execution(setup):
     session = setup
mars/services/task/execution/ray/context.py

Lines changed: 22 additions & 3 deletions

@@ -14,13 +14,15 @@

 import functools
 import inspect
-from typing import Union
+import logging
+from typing import Union, Dict, List

 from .....core.context import Context
 from .....utils import implements, lazy_import
 from ....context import ThreadedServiceContext

 ray = lazy_import("ray")
+logger = logging.getLogger(__name__)


 class RayRemoteObjectManager:
@@ -89,7 +91,14 @@ def _get_task_state_actor(self) -> "ray.actor.ActorHandle":
     @implements(Context.create_remote_object)
     def create_remote_object(self, name: str, object_cls, *args, **kwargs):
         task_state_actor = self._get_task_state_actor()
-        task_state_actor.create_remote_object.remote(name, object_cls, *args, **kwargs)
+        r = task_state_actor.create_remote_object.remote(
+            name, object_cls, *args, **kwargs
+        )
+        # Make sure the actor is created. The remote object may not be created
+        # yet when get_remote_object is called from a worker, because the
+        # callers of create_remote_object and get_remote_object may not run in
+        # the same worker. A sync Ray actor would require this `ray.get`, too.
+        ray.get(r)
         return _RayRemoteObjectWrapper(task_state_actor, name)

     @implements(Context.get_remote_object)
@@ -107,7 +116,17 @@ def destroy_remote_object(self, name: str):
 class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
     """The context for tiling."""

-    pass
+    def __init__(self, task_context: Dict, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._task_context = task_context
+
+    @implements(Context.get_chunks_result)
+    def get_chunks_result(self, data_keys: List[str]) -> List:
+        logger.info("Getting %s chunks result.", len(data_keys))
+        object_refs = [self._task_context[key] for key in data_keys]
+        result = ray.get(object_refs)
+        logger.info("Got %s chunks result.", len(result))
+        return result


 # TODO(fyrestone): Implement more APIs for Ray.
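This is the heart of the commit: RayExecutionContext now holds the task-level mapping from chunk data keys to Ray ObjectRefs and answers get_chunks_result with one bulk ray.get. A self-contained sketch of that contract, assuming a running Ray instance (the dict and key names are illustrative, not Mars API):

    import ray

    ray.init()

    # Stand-in for the executor-populated task context: data key -> ObjectRef.
    task_context = {"chunk-0": ray.put([1, 2, 3]), "chunk-1": ray.put([4, 5, 6])}

    def get_chunks_result(data_keys):
        # Resolve keys to ObjectRefs, then fetch them all in a single ray.get.
        return ray.get([task_context[key] for key in data_keys])

    assert get_chunks_result(["chunk-0", "chunk-1"]) == [[1, 2, 3], [4, 5, 6]]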

mars/services/task/execution/ray/executor.py

Lines changed: 51 additions & 26 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import asyncio
+import functools
 import logging
 from typing import List, Dict, Any, Set
 from .....core import ChunkGraph, Chunk, TileContext
@@ -123,22 +124,22 @@ def __init__(
         config: ExecutionConfig,
         task: Task,
         tile_context: TileContext,
-        ray_executor: "ray.remote_function.RemoteFunction",
+        task_context: Dict[str, "ray.ObjectRef"],
         task_state_actor: "ray.actor.ActorHandle",
         lifecycle_api: LifecycleAPI,
         meta_api: MetaAPI,
     ):
         self._config = config
         self._task = task
         self._tile_context = tile_context
-        self._ray_executor = ray_executor
+        self._task_context = task_context
         self._task_state_actor = task_state_actor
+        self._ray_executor = self._get_ray_executor()

         # api
         self._lifecycle_api = lifecycle_api
         self._meta_api = meta_api

-        self._task_context = {}
         self._available_band_resources = None

         # For progress
@@ -158,19 +159,19 @@ async def create(
         tile_context: TileContext,
         **kwargs,
     ) -> "TaskExecutor":
-        ray_executor = ray.remote(execute_subtask)
         lifecycle_api, meta_api = await cls._get_apis(session_id, address)
         task_state_actor = (
             ray.remote(RayTaskState)
             .options(name=RayTaskState.gen_name(task.task_id))
             .remote()
         )
-        await cls._init_context(task_state_actor, session_id, address)
+        task_context = {}
+        await cls._init_context(task_context, task_state_actor, session_id, address)
         return cls(
             config,
             task,
             tile_context,
-            ray_executor,
+            task_context,
             task_state_actor,
             lifecycle_api,
             meta_api,
@@ -184,13 +185,29 @@ async def _get_apis(cls, session_id: str, address: str):
             MetaAPI.create(session_id, address),
         )

+    @staticmethod
+    @functools.lru_cache(maxsize=1)
+    def _get_ray_executor():
+        # Export the remote function only once.
+        return ray.remote(execute_subtask)
+
     @classmethod
     async def _init_context(
-        cls, task_state_actor: "ray.actor.ActorHandle", session_id: str, address: str
+        cls,
+        task_context: Dict[str, "ray.ObjectRef"],
+        task_state_actor: "ray.actor.ActorHandle",
+        session_id: str,
+        address: str,
     ):
         loop = asyncio.get_running_loop()
         context = RayExecutionContext(
-            task_state_actor, session_id, address, address, address, loop=loop
+            task_context,
+            task_state_actor,
+            session_id,
+            address,
+            address,
+            address,
+            loop=loop,
         )
         await context.init()
         set_context(context)
@@ -204,7 +221,7 @@ async def execute_subtask_graph(
         context: Any = None,
     ) -> Dict[Chunk, ExecutionChunkResult]:
         logger.info("Stage %s start.", stage_id)
-        context = self._task_context
+        task_context = self._task_context
         output_meta_object_refs = []
         self._pre_all_stages_tile_progress = (
             self._pre_all_stages_tile_progress + self._cur_stage_tile_progress
@@ -221,7 +238,7 @@
         for subtask in subtask_graph.topological_iter():
             subtask_chunk_graph = subtask.chunk_graph
             key_to_input = await self._load_subtask_inputs(
-                stage_id, subtask, subtask_chunk_graph, context
+                stage_id, subtask, subtask_chunk_graph, task_context
             )
             output_keys = self._get_subtask_output_keys(subtask_chunk_graph)
             output_meta_keys = result_meta_keys & output_keys
@@ -245,32 +262,34 @@
             meta_object_ref, *output_object_refs = output_object_refs
             # TODO(fyrestone): Fetch(not get) meta object here.
             output_meta_object_refs.append(meta_object_ref)
-            context.update(zip(output_keys, output_object_refs))
+            task_context.update(zip(output_keys, output_object_refs))
         logger.info("Submitted %s subtasks of stage %s.", len(subtask_graph), stage_id)

         key_to_meta = {}
         if len(output_meta_object_refs) > 0:
             # TODO(fyrestone): Optimize update meta by fetching partial meta.
+            meta_count = len(output_meta_object_refs)
+            logger.info("Getting %s metas of stage %s.", meta_count, stage_id)
             meta_list = await asyncio.gather(*output_meta_object_refs)
             for meta in meta_list:
                 key_to_meta.update(meta)
             assert len(key_to_meta) == len(result_meta_keys)
-            logger.info(
-                "Got %s metas of stage %s.", len(output_meta_object_refs), stage_id
-            )
+            logger.info("Got %s metas of stage %s.", meta_count, stage_id)

         chunk_to_meta = {}
-        output_object_refs = []
+        # ray.wait requires the object ref list to be unique.
+        output_object_refs = set()
         for chunk in chunk_graph.result_chunks:
             chunk_key = chunk.key
-            object_ref = context[chunk_key]
-            output_object_refs.append(object_ref)
+            object_ref = task_context[chunk_key]
+            output_object_refs.add(object_ref)
             chunk_meta = key_to_meta.get(chunk_key)
             if chunk_meta is not None:
                 chunk_to_meta[chunk] = ExecutionChunkResult(chunk_meta, object_ref)

         logger.info("Waiting for stage %s complete.", stage_id)
-        ray.wait(output_object_refs, fetch_local=False)
+        # asyncio.to_thread is patched in for Python < 3.9 at mars/lib/aio/__init__.py.
+        await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)
         # Just use `self._cur_stage_tile_progress` as current stage progress
         # because current stage is finished, its progress is 1.
         self._pre_all_stages_progress += self._cur_stage_tile_progress
@@ -289,14 +308,20 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
             chunk_keys = []
             for chunk in self._tile_context[tileable].chunks:
                 chunk_keys.append(chunk.key)
-                object_ref = self._task_context[chunk.key]
-                update_metas.append(
-                    self._meta_api.set_chunk_meta.delay(
-                        chunk,
-                        bands=[],
-                        object_ref=object_ref,
+                if chunk.key in self._task_context:
+                    # A tileable graph may have result chunks that are never
+                    # executed, for example:
+                    #   r, b = cut(series, bins, retbins=True)
+                    #   r_result = r.execute().fetch()
+                    #   b_result = b.execute().fetch()  <- this is the case
+                    object_ref = self._task_context[chunk.key]
+                    update_metas.append(
+                        self._meta_api.set_chunk_meta.delay(
+                            chunk,
+                            bands=[],
+                            object_ref=object_ref,
+                        )
                     )
-                )
             update_lifecycles.append(
                 self._lifecycle_api.track.delay(tileable.key, chunk_keys)
             )
@@ -325,7 +350,7 @@ async def get_progress(self) -> float:
         finished_objects, _ = ray.wait(
             self._cur_stage_output_object_refs,
             num_returns=total,
-            timeout=0.1,
+            timeout=0,  # Avoid blocking the asyncio loop.
             fetch_local=False,
         )
         stage_progress = (
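Two details of the executor rework are easy to miss: result ObjectRefs are now collected into a set because ray.wait requires unique refs, and the blocking ray.wait call is moved onto a worker thread so the asyncio loop stays responsive (which is also why get_progress switched to timeout=0). A condensed sketch of the waiting pattern, assuming Python 3.9+ for asyncio.to_thread (Mars backports it for older interpreters):

    import asyncio
    import ray

    async def wait_stage(object_refs):
        # Deduplicate first: ray.wait requires a list of unique ObjectRefs.
        unique_refs = list(set(object_refs))
        # Run the blocking wait in a thread; the event loop keeps servicing
        # progress queries in the meantime.
        await asyncio.to_thread(ray.wait, unique_refs, fetch_local=False)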

mars/services/task/execution/ray/tests/test_ray_execution_backend.py

Lines changed: 31 additions & 2 deletions

@@ -20,10 +20,15 @@

 from ......core.graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
 from ......serialization import serialize
-from ......tests.core import require_ray
+from ......tests.core import require_ray, mock
 from ......utils import lazy_import, get_chunk_params
+from .....context import ThreadedServiceContext
 from ....core import new_task_id
-from ..context import RayRemoteObjectManager, _RayRemoteObjectContext
+from ..context import (
+    RayExecutionContext,
+    RayRemoteObjectManager,
+    _RayRemoteObjectContext,
+)
 from ..executor import execute_subtask
 from ..fetcher import RayFetcher

@@ -119,3 +124,27 @@
     context.destroy_remote_object(name)
     with pytest.raises(KeyError):
         remote_object.foo(3, 4)
+
+    class MyException(Exception):
+        pass
+
+    class _ErrorRemoteObject:
+        def __init__(self):
+            raise MyException()
+
+    with pytest.raises(MyException):
+        context.create_remote_object(name, _ErrorRemoteObject)
+
+
+@require_ray
+def test_get_chunks_result(ray_start_regular_shared2):
+    value = 123
+    o = ray.put(value)
+
+    def fake_init(self):
+        pass
+
+    with mock.patch.object(ThreadedServiceContext, "__init__", new=fake_init):
+        context = RayExecutionContext({"abc": o}, None)
+        r = context.get_chunks_result(["abc"])
+        assert r == [value]
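The new test sidesteps the full service stack: it mocks out ThreadedServiceContext.__init__ so a RayExecutionContext can be built from nothing but a task-context dict, then checks that get_chunks_result resolves a key to the value behind a real ray.put ObjectRef.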
