Merged: changes from all commits

[first changed file]
@@ -35,6 +35,7 @@
 from servicelib.logging_utils import log_catch, log_context
 from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient
 from servicelib.redis import RedisClientSDK
+from servicelib.utils import limited_gather
 from sqlalchemy.ext.asyncio import AsyncEngine
 
 from ...constants import UNDEFINED_STR_METADATA
@@ -79,6 +80,7 @@
 _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(
     seconds=30
 )
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10
 
 
 def _auto_schedule_callback(
@@ -336,7 +338,7 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
             project_id, dag
         )
         if running_tasks := [t for t in tasks.values() if _need_heartbeat(t)]:
-            await asyncio.gather(
+            await limited_gather(
                 *(
                     publish_service_resource_tracking_heartbeat(
                         self.rabbitmq_client,
@@ -345,17 +347,15 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
                         ),
                     )
                     for t in running_tasks
-                )
+                ),
+                log=_logger,
+                limit=_PUBLICATION_CONCURRENCY_LIMIT,
             )
-            comp_tasks_repo = CompTasksRepository(self.db_engine)
-            await asyncio.gather(
-                *(
-                    comp_tasks_repo.update_project_task_last_heartbeat(
-                        t.project_id, t.node_id, run_id, utc_now
-                    )
-                    for t in running_tasks
-                )
-            )
+            comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+            for task in running_tasks:
+                await comp_tasks_repo.update_project_task_last_heartbeat(
+                    project_id, task.node_id, run_id, utc_now
+                )
 
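Note on limited_gather: this PR repeatedly swaps asyncio.gather for
servicelib.utils.limited_gather. The real implementation lives in servicelib;
judging only from the call sites in this diff (the limit=, log= and reraise=
keywords), a minimal self-contained sketch of the semantics the scheduler now
relies on could look like this (hypothetical limited_gather_sketch, not the
servicelib code):

import asyncio
import logging
from collections.abc import Awaitable
from typing import Any

_log = logging.getLogger(__name__)


async def limited_gather_sketch(
    *awaitables: Awaitable[Any],
    log: logging.Logger,
    limit: int,
    reraise: bool = True,
) -> list[Any]:
    # a semaphore caps how many coroutine bodies run at the same time
    semaphore = asyncio.Semaphore(limit)

    async def _bounded(awaitable: Awaitable[Any]) -> Any:
        async with semaphore:
            return await awaitable

    results = await asyncio.gather(
        *(_bounded(a) for a in awaitables), return_exceptions=not reraise
    )
    # with reraise=False, failures come back as exception objects: log them
    for result in results:
        if isinstance(result, BaseException):
            log.warning("gathered call failed: %r", result)
    return results


async def _demo(i: int) -> int:
    await asyncio.sleep(0.01)
    return i


print(asyncio.run(limited_gather_sketch(*(_demo(i) for i in range(25)), log=_log, limit=10)))
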
     async def _get_changed_tasks_from_backend(
         self,
@@ -400,7 +400,7 @@ async def _process_started_tasks(
         utc_now = arrow.utcnow().datetime
 
         # resource tracking
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_resource_tracking_started(
                     self.rabbitmq_client,
@@ -462,10 +462,12 @@ async def _process_started_tasks(
                     service_additional_metadata={},
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
         # instrumentation
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_started_metrics(
                     self.rabbitmq_client,
@@ -476,24 +478,22 @@ async def _process_started_tasks(
                     task=t,
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
 
         # update DB
         comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.project_id,
-                    run_id,
-                    [t.node_id],
-                    t.state,
-                    optional_started=utc_now,
-                    optional_progress=t.progress,
-                )
-                for t in tasks
-            )
-        )
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                project_id,
+                run_id,
+                [task.node_id],
+                task.state,
+                optional_started=utc_now,
+                optional_progress=task.progress,
+            )
         await CompRunsRepository.instance(self.db_engine).mark_as_started(
             user_id=user_id,
             project_id=project_id,
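
Note on the dropped asyncio.gather around the DB writes: with an unbounded
gather, every update checks out its own connection from the AsyncEngine pool
at the same time; the plain loop holds at most one connection per scheduler
pass and applies the updates in a deterministic order. A runnable
illustration of the difference (FakeRepo is a hypothetical stand-in for
CompTasksRepository):

import asyncio


class FakeRepo:
    """Counts how many updates are in flight, like pooled DB connections."""

    def __init__(self) -> None:
        self.in_flight = 0
        self.max_in_flight = 0

    async def update(self, task_id: int) -> None:
        self.in_flight += 1
        self.max_in_flight = max(self.max_in_flight, self.in_flight)
        await asyncio.sleep(0)  # simulate the round-trip to the database
        self.in_flight -= 1


async def main() -> None:
    repo = FakeRepo()
    await asyncio.gather(*(repo.update(i) for i in range(50)))
    print("gather:", repo.max_in_flight, "updates in flight at peak")  # 50

    repo = FakeRepo()
    for i in range(50):
        await repo.update(i)
    print("loop:  ", repo.max_in_flight, "updates in flight at peak")  # 1


asyncio.run(main())
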
@@ -504,18 +504,14 @@ async def _process_waiting_tasks(
     async def _process_waiting_tasks(
         self, tasks: list[TaskStateTracker], run_id: PositiveInt
     ) -> None:
-        comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.current.project_id,
-                    run_id,
-                    [t.current.node_id],
-                    t.current.state,
-                )
-                for t in tasks
-            )
-        )
+        comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                task.current.project_id,
+                run_id,
+                [task.current.node_id],
+                task.current.state,
+            )
 
     async def _update_states_from_comp_backend(
         self,
[second changed file]
@@ -48,15 +48,12 @@
     publish_service_stopped_metrics,
 )
 from ..clusters_keeper import get_or_create_on_demand_cluster
-from ..dask_client import DaskClient, PublishedComputationTask
+from ..dask_client import DaskClient
 from ..dask_clients_pool import DaskClientsPool
 from ..db.repositories.comp_runs import (
     CompRunsRepository,
 )
 from ..db.repositories.comp_tasks import CompTasksRepository
-from ._constants import (
-    MAX_CONCURRENT_PIPELINE_SCHEDULING,
-)
 from ._models import TaskStateTracker
 from ._scheduler_base import BaseCompScheduler
 from ._utils import (
@@ -68,6 +65,7 @@
 _DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{project_id}:{run_id}"
 _TASK_RETRIEVAL_ERROR_TYPE: Final[str] = "task-result-retrieval-timeout"
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10
 
 
 @asynccontextmanager
@@ -149,37 +147,31 @@ async def _start_tasks(
                 RunningState.PENDING,
             )
             # each task is started independently
-            results: list[list[PublishedComputationTask]] = await limited_gather(
-                *(
-                    client.send_computation_tasks(
-                        user_id=user_id,
-                        project_id=project_id,
-                        tasks={node_id: task.image},
-                        hardware_info=task.hardware_info,
-                        callback=wake_up_callback,
-                        metadata=comp_run.metadata,
-                        resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
-                            user_id, project_id, node_id, comp_run.iteration
-                        ),
-                    )
-                    for node_id, task in scheduled_tasks.items()
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
-
-            # update the database so we do have the correct job_ids there
-            await limited_gather(
-                *(
-                    comp_tasks_repo.update_project_task_job_id(
-                        project_id, task.node_id, comp_run.run_id, task.job_id
-                    )
-                    for task_sents in results
-                    for task in task_sents
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
+            for node_id, task in scheduled_tasks.items():
+                published_tasks = await client.send_computation_tasks(
+                    user_id=user_id,
+                    project_id=project_id,
+                    tasks={node_id: task.image},
+                    hardware_info=task.hardware_info,
+                    callback=wake_up_callback,
+                    metadata=comp_run.metadata,
+                    resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
+                        user_id, project_id, node_id, comp_run.iteration
+                    ),
+                )
+
+                # update the database so we do have the correct job_ids there
+                await limited_gather(
+                    *(
+                        comp_tasks_repo.update_project_task_job_id(
+                            project_id, task.node_id, comp_run.run_id, task.job_id
+                        )
+                        for task in published_tasks
+                    ),
+                    log=_logger,
+                    limit=1,
+                )
 
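Note on the _start_tasks rewrite: tasks are now sent to the dask scheduler
one node at a time, and the returned job_ids are persisted right after each
submission (limit=1, i.e. sequentially) instead of fanning out both steps
with limited_gather. A failure part-way through thus leaves every
already-submitted task with its job_id recorded. A minimal sketch of that
submit-then-checkpoint shape, with hypothetical submit/persist_job_id
stand-ins for client.send_computation_tasks and
comp_tasks_repo.update_project_task_job_id:

import asyncio


async def submit(node_id: str) -> str:
    await asyncio.sleep(0)  # pretend to talk to the dask scheduler
    return f"job-for-{node_id}"


async def persist_job_id(node_id: str, job_id: str) -> None:
    await asyncio.sleep(0)  # pretend to write the job_id to the database


async def start_all(node_ids: list[str]) -> None:
    for node_id in node_ids:
        job_id = await submit(node_id)
        # checkpoint immediately: a crash here loses no submitted job_id
        await persist_job_id(node_id, job_id)


asyncio.run(start_all(["n1", "n2", "n3"]))
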
     async def _get_tasks_status(
         self,
@@ -208,7 +200,7 @@ async def _process_executing_tasks(
         tasks: list[CompTaskAtDB],
         comp_run: CompRunsAtDB,
     ) -> None:
-        task_progresses = []
+        task_progress_events = []
         try:
             async with _cluster_dask_client(
                 user_id,
@@ -218,42 +210,33 @@ async def _process_executing_tasks(
                 run_id=comp_run.run_id,
                 run_metadata=comp_run.metadata,
             ) as client:
-                task_progresses = [
+                task_progress_events = [
                     t
                     for t in await client.get_tasks_progress(
                         [f"{t.job_id}" for t in tasks],
                     )
                     if t is not None
                 ]
-                await limited_gather(
-                    *(
-                        CompTasksRepository(self.db_engine).update_project_task_progress(
-                            t.task_owner.project_id,
-                            t.task_owner.node_id,
-                            comp_run.run_id,
-                            t.progress,
-                        )
-                        for t in task_progresses
-                    ),
-                    log=_logger,
-                    limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-                )
+                for progress_event in task_progress_events:
+                    await CompTasksRepository(self.db_engine).update_project_task_progress(
+                        progress_event.task_owner.project_id,
+                        progress_event.task_owner.node_id,
+                        comp_run.run_id,
+                        progress_event.progress,
+                    )
 
         except ComputationalBackendOnDemandNotReadyError:
             _logger.info("The on demand computational backend is not ready yet...")
 
         comp_tasks_repo = CompTasksRepository(self.db_engine)
+        for task in task_progress_events:
+            await comp_tasks_repo.update_project_task_progress(
+                task.task_owner.project_id,
+                task.task_owner.node_id,
+                comp_run.run_id,
+                task.progress,
+            )
         await limited_gather(
-            *(
-                comp_tasks_repo.update_project_task_progress(
-                    t.task_owner.project_id,
-                    t.task_owner.node_id,
-                    comp_run.run_id,
-                    t.progress,
-                )
-                for t in task_progresses
-                if t
-            ),
             *(
                 publish_service_progress(
                     self.rabbitmq_client,
@@ -262,11 +245,10 @@ async def _process_executing_tasks(
                     node_id=t.task_owner.node_id,
                     progress=t.progress,
                 )
-                for t in task_progresses
-                if t
+                for t in task_progress_events
             ),
             log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )

     async def _release_resources(self, comp_run: CompRunsAtDB) -> None:
@@ -300,25 +282,14 @@ async def _stop_tasks(
                 run_id=comp_run.run_id,
                 run_metadata=comp_run.metadata,
             ) as client:
-                await limited_gather(
-                    *(
-                        client.abort_computation_task(t.job_id)
-                        for t in tasks
-                        if t.job_id
-                    ),
-                    log=_logger,
-                    limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-                )
-                # tasks that have no worker must be unpublished as they block forever
-                await limited_gather(
-                    *(
-                        client.release_task_result(t.job_id)
-                        for t in tasks
-                        if t.state is RunningState.WAITING_FOR_RESOURCES and t.job_id
-                    ),
-                    log=_logger,
-                    limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-                )
+                for t in tasks:
+                    if not t.job_id:
+                        _logger.warning("%s has no job_id, cannot be stopped", t)
+                        continue
+                    await client.abort_computation_task(t.job_id)
+                    # tasks that have no worker must be unpublished as they block forever
+                    if t.state is RunningState.WAITING_FOR_RESOURCES:
+                        await client.release_task_result(t.job_id)
 
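Note on the sequential stop loop: a task without a job_id was never
submitted, so there is nothing to abort; tasks stuck in
WAITING_FOR_RESOURCES have no worker, their futures never complete, and
releasing the published result is what lets the dask scheduler forget them.
Against dask.distributed directly (client created with asynchronous=True),
aborting a published job could look roughly like the sketch below; this is
not necessarily how DaskClient.abort_computation_task is implemented:

from distributed import Client


async def abort_job(client: Client, job_id: str) -> None:
    # recover the future that was published under this job_id ...
    task_future = await client.get_dataset(name=job_id)
    # ... and cancel it; force=True cancels even if other clients hold it
    await client.cancel(task_future, force=True)
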
     async def _process_completed_tasks(
         self,
@@ -342,7 +313,7 @@ async def _process_completed_tasks(
             ),
             reraise=False,
             log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=1,  # to avoid overloading the dask scheduler
         )
         async for future in limited_as_completed(
             (
@@ -354,7 +325,7 @@ async def _process_completed_tasks(
                 )
                 for task, result in zip(tasks, tasks_results, strict=True)
             ),
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=10,  # this is not accessing the dask scheduler (only the db)
         ):
             with log_catch(_logger, reraise=False):
                 task_can_be_cleaned, job_id = await future
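Note on the two hard-coded limits above: limited_as_completed (also from
servicelib) is the streaming counterpart of limited_gather, yielding futures
as they finish while keeping at most `limit` of them in flight; hence the
dask-facing pass runs with limit=1 and the db-only pass with limit=10. A
self-contained sketch of such a helper (again hypothetical, not the
servicelib implementation):

import asyncio
from collections.abc import AsyncIterator, Awaitable, Iterable
from typing import TypeVar

T = TypeVar("T")


async def limited_as_completed_sketch(
    awaitables: Iterable[Awaitable[T]], *, limit: int
) -> AsyncIterator[asyncio.Future[T]]:
    # yield futures as they complete, never running more than `limit` at once
    iterator = iter(awaitables)
    pending: set[asyncio.Future[T]] = set()

    def _top_up() -> None:
        while len(pending) < limit:
            try:
                pending.add(asyncio.ensure_future(next(iterator)))
            except StopIteration:
                return

    _top_up()
    while pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        for future in done:
            yield future
        _top_up()

Callers consume it exactly as in the diff above: async for future in
limited_as_completed_sketch(coros, limit=10), then result = await future.
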
[third changed file]
@@ -8,7 +8,6 @@

"""

import asyncio
import logging
from collections.abc import Callable, Iterable
from dataclasses import dataclass
@@ -241,9 +240,6 @@ def _comp_sidecar_fct(
             )
             # NOTE: the callback is running in a secondary thread, and takes a future as arg
             task_future.add_done_callback(lambda _: callback())
-            await distributed.Variable(job_id, client=self.backend.client).set(
-                task_future
-            )
 
             await dask_utils.wrap_client_async_routine(
                 self.backend.client.publish_dataset(task_future, name=job_id)
@@ -560,12 +556,6 @@ async def get_task_result(self, job_id: str) -> TaskOutputData:
     async def release_task_result(self, job_id: str) -> None:
         _logger.debug("releasing results for %s", f"{job_id=}")
         try:
-            # NOTE: The distributed Variable holds the future of the tasks in the dask-scheduler
-            # Alas, deleting the variable is done asynchronously and there is no way to ensure
-            # the variable was effectively deleted.
-            # This is annoying as one can re-create the variable without error.
-            var = distributed.Variable(job_id, client=self.backend.client)
-            await asyncio.get_event_loop().run_in_executor(None, var.delete)
             # first check if the key exists
             await dask_utils.wrap_client_async_routine(
                 self.backend.client.get_dataset(name=job_id)
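
Note on dropping distributed.Variable: the task future was being kept alive
twice, once through the variable and once through the published dataset, and
as the removed NOTE says the variable's deletion could not even be
confirmed. After this PR the published dataset is the single keep-alive
mechanism. Roughly, against a dask.distributed Client created with
asynchronous=True (a sketch of the lifecycle, not the exact DaskClient
code):

from distributed import Client


async def keep_alive_and_release(client: Client, job_id: str) -> None:
    # publishing the future as a named dataset keeps a reference to it
    # (and hence to the task and its result) alive on the scheduler side
    task_future = client.submit(lambda: 42, key=job_id)
    await client.publish_dataset(task_future, name=job_id)

    # dropping the published name releases that reference, so the
    # scheduler can forget the task and free its result
    await client.unpublish_dataset(name=job_id)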