|
24 | 24 | from .... import oscar as mo |
25 | 25 | from ....core.operand import Fetch |
26 | 26 | from ....lib.aio import alru_cache |
27 | | -from ....oscar.backends.message import ProfilingContext |
28 | 27 | from ....oscar.errors import MarsError |
29 | 28 | from ....typing import BandType |
30 | 29 | from ....utils import dataslots |
@@ -95,8 +94,7 @@ async def __post_create__(self): |
95 | 94 | AssignerActor.gen_uid(self._session_id), address=self.address |
96 | 95 | ) |
97 | 96 |
|
98 | | - @alru_cache |
99 | | - async def _get_task_api(self): |
| 97 | + async def _get_task_api(self) -> TaskAPI: |
100 | 98 | return await TaskAPI.create(self._session_id, self.address) |
101 | 99 |
|
102 | 100 | def _put_subtask_with_priority(self, subtask: Subtask, priority: Tuple = None): |
@@ -272,21 +270,47 @@ async def update_subtask_priorities( |
272 | 270 |
|
273 | 271 | @alru_cache(maxsize=10000) |
274 | 272 | async def _get_execution_ref(self, address: str): |
275 | | - from ..worker.exec import SubtaskExecutionActor |
| 273 | + from ..worker.execution import SubtaskExecutionActor |
276 | 274 |
|
277 | 275 | return await mo.actor_ref(SubtaskExecutionActor.default_uid(), address=address) |
278 | 276 |
|
279 | | - async def finish_subtasks(self, subtask_ids: List[str], schedule_next: bool = True): |
280 | | - band_tasks = defaultdict(lambda: 0) |
281 | | - for subtask_id in subtask_ids: |
282 | | - subtask_info = self._subtask_infos.pop(subtask_id, None) |
| 277 | + async def set_subtask_results( |
| 278 | + self, subtask_results: List[SubtaskResult], source_bands: List[BandType] |
| 279 | + ): |
| 280 | + delays = [] |
| 281 | + task_api = await self._get_task_api() |
| 282 | + for result, band in zip(subtask_results, source_bands): |
| 283 | + if result.status == SubtaskStatus.errored: |
| 284 | + subtask_info = self._subtask_infos.get(result.subtask_id) |
| 285 | + if ( |
| 286 | + subtask_info is not None |
| 287 | + and subtask_info.subtask.retryable |
| 288 | + and subtask_info.num_reschedules < subtask_info.max_reschedules |
| 289 | + and isinstance(result.error, (MarsError, OSError)) |
| 290 | + ): |
| 291 | + subtask_info.num_reschedules += 1 |
| 292 | + logger.warning( |
| 293 | + "Resubmit subtask %s at attempt %d", |
| 294 | + subtask_info.subtask.subtask_id, |
| 295 | + subtask_info.num_reschedules, |
| 296 | + ) |
| 297 | + execution_ref = await self._get_execution_ref(band[0]) |
| 298 | + await execution_ref.submit_subtasks.tell( |
| 299 | + [subtask_info.subtask], |
| 300 | + [subtask_info.priority], |
| 301 | + self.address, |
| 302 | + band[1], |
| 303 | + ) |
| 304 | + continue |
| 305 | + |
| 306 | + subtask_info = self._subtask_infos.pop(result.subtask_id, None) |
283 | 307 | if subtask_info is not None: |
284 | | - self._subtask_summaries[subtask_id] = subtask_info.to_summary( |
| 308 | + self._subtask_summaries[result.subtask_id] = subtask_info.to_summary( |
285 | 309 | is_finished=True |
286 | 310 | ) |
287 | | - if schedule_next: |
288 | | - for band in subtask_info.submitted_bands: |
289 | | - band_tasks[band] += 1 |
| 311 | + delays.append(task_api.set_subtask_result.delay(result)) |
| 312 | + |
| 313 | + await task_api.set_subtask_result.batch(*delays) |
290 | 314 |
|
291 | 315 | def _get_subtasks_by_ids(self, subtask_ids: List[str]) -> List[Optional[Subtask]]: |
292 | 316 | subtasks = [] |
|
0 commit comments