Skip to content

Commit 7b24842

Browse files
authored
[Ray] Implement ray task executor progress (#3008)
1 parent d23b588 commit 7b24842

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

mars/deploy/oscar/tests/test_ray_dag.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,9 @@ async def test_iterative_tiling(ray_start_regular_shared2, create_cluster):
116116
@pytest.mark.parametrize("config", [{"backend": "ray"}])
117117
def test_sync_execute(config):
118118
test_local.test_sync_execute(config)
119+
120+
121+
@require_ray
122+
@pytest.mark.asyncio
123+
async def test_session_progress(ray_start_regular_shared2, create_cluster):
124+
test_local.test_session_progress(create_cluster)

mars/services/task/execution/ray/executor.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ def __init__(
131131
self._task_context = {}
132132
self._available_band_resources = None
133133

134+
# For progress
135+
self._pre_all_stages_progress = 0.0
136+
self._pre_all_stages_tile_progress = 0
137+
self._cur_stage_tile_progress = 0
138+
self._cur_stage_output_object_refs = []
139+
134140
@classmethod
135141
async def create(
136142
cls,
@@ -190,7 +196,12 @@ async def execute_subtask_graph(
190196
logger.info("Stage %s start.", stage_id)
191197
context = self._task_context
192198
output_meta_object_refs = []
193-
199+
self._pre_all_stages_tile_progress = (
200+
self._pre_all_stages_tile_progress + self._cur_stage_tile_progress
201+
)
202+
self._cur_stage_tile_progress = (
203+
self._tile_context.get_all_progress() - self._pre_all_stages_tile_progress
204+
)
194205
logger.info("Submitting %s subtasks of stage %s.", len(subtask_graph), stage_id)
195206
result_meta_keys = {
196207
chunk.key
@@ -219,6 +230,7 @@ async def execute_subtask_graph(
219230
continue
220231
elif output_count == 1:
221232
output_object_refs = [output_object_refs]
233+
self._cur_stage_output_object_refs.extend(output_object_refs)
222234
if output_meta_keys:
223235
meta_object_ref, *output_object_refs = output_object_refs
224236
# TODO(fyrestone): Fetch(not get) meta object here.
@@ -249,6 +261,10 @@ async def execute_subtask_graph(
249261

250262
logger.info("Waiting for stage %s complete.", stage_id)
251263
ray.wait(output_object_refs, fetch_local=False)
264+
# Just use `self._cur_stage_tile_progress` as current stage progress
265+
# because current stage is finished, its progress is 1.
266+
self._pre_all_stages_progress += self._cur_stage_tile_progress
267+
self._cur_stage_output_object_refs.clear()
252268
logger.info("Stage %s is complete.", stage_id)
253269
return chunk_to_meta
254270

@@ -293,7 +309,19 @@ async def get_available_band_resources(self) -> Dict[BandType, Resource]:
293309

294310
async def get_progress(self) -> float:
295311
"""Get the execution progress."""
296-
return 1
312+
stage_progress = 0.0
313+
total = len(self._cur_stage_output_object_refs)
314+
if total > 0:
315+
finished_objects, _ = ray.wait(
316+
self._cur_stage_output_object_refs,
317+
num_returns=total,
318+
timeout=0.1,
319+
fetch_local=False,
320+
)
321+
stage_progress = (
322+
len(finished_objects) / total * self._cur_stage_tile_progress
323+
)
324+
return self._pre_all_stages_progress + stage_progress
297325

298326
async def cancel(self):
299327
"""Cancel execution."""

0 commit comments

Comments
 (0)