Skip to content

Commit ea13a3d

Browse files
authored
Reduce RPC cost of oscar by removing unnecessary tasks (#2597)
1 parent fecfe6c commit ea13a3d

File tree

10 files changed

+120
-83
lines changed

10 files changed

+120
-83
lines changed

mars/deploy/oscar/session.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,28 @@ async def init(
816816
else:
817817
return await cls._init(address, session_id, new=new, timeout=timeout)
818818

819+
async def _update_progress(self, task_id: str, progress: Progress):
820+
zero_acc_time = 0
821+
delay = 0.5
822+
while True:
823+
try:
824+
last_progress_value = progress.value
825+
progress.value = await self._task_api.get_task_progress(task_id)
826+
if abs(progress.value - last_progress_value) < 1e-4:
827+
# if percentage does not change, we add delay time by 0.5 seconds every time
828+
zero_acc_time = min(5, zero_acc_time + 0.5)
829+
delay = zero_acc_time
830+
else:
831+
# percentage changes, we use percentage speed to calc progress time
832+
zero_acc_time = 0
833+
speed = abs(progress.value - last_progress_value) / delay
834+
# one percent for one second
835+
delay = 0.01 / speed
836+
delay = max(0.5, min(delay, 5.0))
837+
await asyncio.sleep(delay)
838+
except asyncio.CancelledError:
839+
break
840+
819841
async def _run_in_background(
820842
self,
821843
tileables: list,
@@ -826,55 +848,52 @@ async def _run_in_background(
826848
with enter_mode(build=True, kernel=True):
827849
# wait for task to finish
828850
cancelled = False
851+
progress_task = asyncio.create_task(
852+
self._update_progress(task_id, progress)
853+
)
829854
start_time = time.time()
830-
while True:
831-
try:
832-
if not cancelled:
833-
task_result: Union[TaskResult, None] = None
834-
try:
835-
task_result = await self._task_api.wait_task(
836-
task_id, timeout=0.5
837-
)
838-
finally:
839-
if task_result is None:
840-
# not finished, set progress
841-
progress.value = await self._task_api.get_task_progress(
842-
task_id
843-
)
844-
else:
845-
progress.value = 1.0
846-
if task_result is not None:
847-
break
848-
else:
849-
# wait for task to finish
850-
task_result: TaskResult = await self._task_api.wait_task(
851-
task_id
852-
)
855+
task_result: Optional[TaskResult] = None
856+
try:
857+
if self.timeout is None:
858+
check_interval = 30
859+
else:
860+
elapsed = time.time() - start_time
861+
check_interval = min(self.timeout - elapsed, 30)
862+
863+
while True:
864+
task_result = await self._task_api.wait_task(
865+
task_id, timeout=check_interval
866+
)
867+
if task_result is not None:
853868
break
854-
except asyncio.CancelledError:
855-
# cancelled
856-
cancelled = True
857-
await self._task_api.cancel_task(task_id)
858-
except TimeoutError: # pragma: no cover
859-
# ignore timeout when waiting for subtask progresses
860-
pass
861-
finally:
862-
if (
869+
elif (
863870
self.timeout is not None
864871
and time.time() - start_time > self.timeout
865872
):
866873
raise TimeoutError(
867874
f"Task({task_id}) running time > {self.timeout}"
868875
)
869-
profiling.result = task_result.profiling
870-
if task_result.profiling:
871-
logger.warning(
872-
"Profile task %s execution result:\n%s",
873-
task_id,
874-
json.dumps(task_result.profiling, indent=4),
875-
)
876-
if task_result.error:
877-
raise task_result.error.with_traceback(task_result.traceback)
876+
except asyncio.CancelledError:
877+
# cancelled
878+
cancelled = True
879+
await self._task_api.cancel_task(task_id)
880+
finally:
881+
progress_task.cancel()
882+
if task_result is not None:
883+
progress.value = 1.0
884+
else:
885+
# not finished, set progress
886+
progress.value = await self._task_api.get_task_progress(task_id)
887+
if task_result is not None:
888+
profiling.result = task_result.profiling
889+
if task_result.profiling:
890+
logger.warning(
891+
"Profile task %s execution result:\n%s",
892+
task_id,
893+
json.dumps(task_result.profiling, indent=4),
894+
)
895+
if task_result.error:
896+
raise task_result.error.with_traceback(task_result.traceback)
878897
if cancelled:
879898
return
880899
fetch_tileables = await self._task_api.get_fetch_tileables(task_id)

mars/deploy/oscar/tests/test_local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def f1(interval: float, count: int):
436436
break
437437
await asyncio.sleep(0.1)
438438
else:
439-
raise Exception("progress test failed.")
439+
raise Exception(f"progress test failed, actual value {info.progress()}.")
440440

441441
await info
442442
assert info.result() is None

mars/oscar/backends/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ def _process_result_message(message: Union[ResultMessage, ErrorMessage]):
7171

7272
async def _wait(self, future: asyncio.Future, address: str, message: _MessageBase):
7373
try:
74-
await asyncio.wait([future])
74+
await asyncio.shield(future)
7575
except asyncio.CancelledError:
7676
try:
7777
await self.cancel(address, message.message_id)
7878
except CannotCancelTask:
7979
# cancel failed, already finished
8080
raise asyncio.CancelledError
81+
except: # noqa: E722 # nosec # pylint: disable=bare-except
82+
pass
8183
return await future
8284

8385
async def create_actor(

mars/oscar/backends/mars/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ async def kill_sub_pool(
240240
await asyncio.to_thread(process.join, 5)
241241

242242
async def is_sub_pool_alive(self, process: multiprocessing.Process):
243-
return process.is_alive()
243+
return await asyncio.to_thread(process.is_alive)
244244

245245
async def recover_sub_pool(self, address: str):
246246
process_index = self._config.get_process_index(address)

mars/oscar/backends/pool.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import asyncio
1616
import concurrent.futures as futures
17-
import contextlib
1817
import itertools
1918
import logging
2019
import os
@@ -314,12 +313,10 @@ async def handle_control_command(
314313

315314
return processor.result
316315

317-
@contextlib.contextmanager
318-
def _run_coro(self, message_id: bytes, coro: Coroutine):
319-
future = asyncio.create_task(coro)
320-
self._process_messages[message_id] = future
316+
async def _run_coro(self, message_id: bytes, coro: Coroutine):
317+
self._process_messages[message_id] = asyncio.tasks.current_task()
321318
try:
322-
yield future
319+
return await coro
323320
finally:
324321
self._process_messages.pop(message_id, None)
325322

@@ -332,10 +329,9 @@ async def process_message(self, message: _MessageBase, channel: Channel):
332329
message,
333330
channel,
334331
):
335-
with self._run_coro(
332+
processor.result = await self._run_coro(
336333
message.message_id, handler(self, message)
337-
) as future:
338-
processor.result = await future
334+
)
339335
try:
340336
await channel.send(processor.result)
341337
except (ChannelClosed, ConnectionResetError):
@@ -464,8 +460,7 @@ async def create_actor(self, message: CreateActorMessage) -> result_message_type
464460
actor.uid = actor_id
465461
actor.address = address = self.external_address
466462
self._actors[actor_id] = actor
467-
with self._run_coro(message.message_id, actor.__post_create__()) as future:
468-
await future
463+
await self._run_coro(message.message_id, actor.__post_create__())
469464

470465
result = ActorRef(address, actor_id)
471466
# ensemble result message
@@ -491,8 +486,7 @@ async def destroy_actor(self, message: DestroyActorMessage) -> result_message_ty
491486
actor = self._actors[actor_id]
492487
except KeyError:
493488
raise ActorNotExist(f"Actor {actor_id} does not exist")
494-
with self._run_coro(message.message_id, actor.__pre_destroy__()) as future:
495-
await future
489+
await self._run_coro(message.message_id, actor.__pre_destroy__())
496490
del self._actors[actor_id]
497491

498492
processor.result = ResultMessage(
@@ -523,8 +517,7 @@ async def send(self, message: SendMessage) -> result_message_type:
523517
if actor_id not in self._actors:
524518
raise ActorNotExist(f"Actor {actor_id} does not exist")
525519
coro = self._actors[actor_id].__on_receive__(message.content)
526-
with self._run_coro(message.message_id, coro) as future:
527-
result = await future
520+
result = await self._run_coro(message.message_id, coro)
528521
processor.result = ResultMessage(
529522
message.message_id,
530523
result,

mars/oscar/backends/ray/communication.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ async def send(self, message: Any):
174174
if self._closed.is_set(): # pragma: no cover
175175
raise ChannelClosed("Channel already closed, cannot send message")
176176
# Put ray object ref to queue
177-
await self._in_queue.put(
177+
self._in_queue.put_nowait(
178178
(
179179
message,
180180
self._peer_actor.__on_ray_recv__.remote(
@@ -255,16 +255,15 @@ async def __on_ray_recv__(self, message):
255255
"""This method will be invoked when current process is a ray actor rather than a ray driver"""
256256
self._msg_recv_counter += 1
257257
await self._in_queue.put(message)
258-
# Avoid hang when channel is closed after `self._out_queue.get()` is awaited.
259-
done, _ = await asyncio.wait(
260-
[self._out_queue.get(), self._closed.wait()],
261-
return_when=asyncio.FIRST_COMPLETED,
262-
)
258+
result_message = await self._out_queue.get()
263259
if self._closed.is_set(): # pragma: no cover
264260
raise ChannelClosed("Channel already closed")
265-
if done:
266-
result_message = await done.pop()
267-
return _ArgWrapper(result_message)
261+
return _ArgWrapper(result_message)
262+
263+
@implements(Channel.close)
264+
async def close(self):
265+
await super().close()
266+
self._out_queue.put_nowait(None)
268267

269268

270269
@register_server

mars/oscar/batch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,12 @@ def _gen_args_kwargs_list(delays):
134134
return args_list, kwargs_list
135135

136136
async def _async_batch(self, *delays):
137-
if self.batch_func:
137+
# when there is only one call in batch, calling one-pass method
138+
# will be more efficient
139+
if len(delays) == 1:
140+
d = delays[0]
141+
return [await self.func(*d.args, **d.kwargs)]
142+
elif self.batch_func:
138143
args_list, kwargs_list = self._gen_args_kwargs_list(delays)
139144
return await self.batch_func(args_list, kwargs_list)
140145
else:

mars/oscar/core.pyx

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ cdef class _BaseActor:
212212
return create_actor_ref(self._address, self._uid)
213213

214214
async def _handle_actor_result(self, result):
215-
cdef int result_pos
215+
cdef int idx
216216
cdef tuple res_tuple
217-
cdef list tasks, values
217+
cdef list tasks, coros, coro_poses, values
218218
cdef object coro
219219
cdef bint extract_tuple = False
220220
cdef bint cancelled = False
@@ -228,20 +228,32 @@ cdef class _BaseActor:
228228

229229
if type(result) is tuple:
230230
res_tuple = result
231-
tasks = []
231+
coros = []
232+
coro_poses = []
232233
values = []
233-
for res_item in res_tuple:
234+
for idx, res_item in enumerate(res_tuple):
234235
if is_async_generator(res_item):
235-
value = asyncio.create_task(self._run_actor_async_generator(res_item))
236-
tasks.append(value)
236+
value = self._run_actor_async_generator(res_item)
237+
coros.append(value)
238+
coro_poses.append(idx)
237239
elif inspect.isawaitable(res_item):
238-
value = asyncio.create_task(res_item)
239-
tasks.append(value)
240+
value = res_item
241+
coros.append(value)
242+
coro_poses.append(idx)
240243
else:
241244
value = res_item
242245
values.append(value)
243246

244-
if len(tasks) > 0:
247+
# when there is only one coroutine, we do not need to use
248+
# asyncio.wait as it introduces much overhead
249+
if len(coros) == 1:
250+
task_result = await coros[0]
251+
if extract_tuple:
252+
result = task_result
253+
else:
254+
result = tuple(task_result if t is coros[0] else t for t in values)
255+
elif len(coros) > 0:
256+
tasks = [asyncio.create_task(t) for t in coros]
245257
try:
246258
dones, pending = await asyncio.wait(tasks)
247259
except asyncio.CancelledError:
@@ -254,7 +266,10 @@ cdef class _BaseActor:
254266
if extract_tuple:
255267
result = list(dones)[0].result()
256268
else:
257-
result = tuple(t.result() if t in dones else t for t in values)
269+
for pos in coro_poses:
270+
task = tasks[pos]
271+
values[pos] = task.result()
272+
result = tuple(values)
258273

259274
if cancelled:
260275
# raise in case no CancelledError raised

mars/services/scheduling/supervisor/assigner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def _get_random_band(self, is_gpu: bool):
9292
async def assign_subtasks(self, subtasks: List[Subtask]):
9393
inp_keys = set()
9494
selected_bands = dict()
95+
96+
if not self._bands:
97+
self._update_bands(
98+
list(await self._cluster_api.get_all_bands(NodeRole.WORKER))
99+
)
100+
95101
for subtask in subtasks:
96102
is_gpu = any(c.op.gpu for c in subtask.chunk_graph)
97103
if subtask.expect_bands:
@@ -116,10 +122,6 @@ async def assign_subtasks(self, subtasks: List[Subtask]):
116122
if isinstance(indep_chunk.op, Fetch):
117123
inp_keys.add(indep_chunk.key)
118124
elif isinstance(indep_chunk.op, FetchShuffle):
119-
if not self._bands:
120-
self._update_bands(
121-
list(await self._cluster_api.get_all_bands(NodeRole.WORKER))
122-
)
123125
selected_bands[subtask.subtask_id] = [self._get_random_band(is_gpu)]
124126
break
125127

mars/services/task/supervisor/processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,8 @@ async def wait(self, timeout: int = None):
547547
_, pending = yield asyncio.wait(fs, timeout=timeout)
548548
if not pending:
549549
raise mo.Return(self.result())
550+
else:
551+
[fut.cancel() for fut in pending]
550552

551553
async def cancel(self):
552554
if self._cur_processor:

0 commit comments

Comments
 (0)