@@ -6,13 +6,18 @@
from models_library.rabbitmq_basic_types import RPCMethodName
from models_library.users import UserID
from models_library.wallets import WalletID
from pydantic import TypeAdapter

from ....async_utils import run_sequentially_in_context
from ..._client_rpc import RabbitMQRPCClient
from ..._constants import RPC_REMOTE_METHOD_TIMEOUT_S

_TTL_CACHE_ON_CLUSTERS_S: Final[int] = 5

_GET_OR_CREATE_CLUSTER_METHOD_NAME: Final[RPCMethodName] = TypeAdapter(
RPCMethodName
).validate_python("get_or_create_cluster")


@run_sequentially_in_context(target_args=["user_id", "wallet_id"])
@cached(
@@ -32,7 +37,7 @@ async def get_or_create_cluster(
    # the 2nd decorator ensures that many calls in a short time quickly return the same value
on_demand_cluster: OnDemandCluster = await client.request(
CLUSTERS_KEEPER_RPC_NAMESPACE,
RPCMethodName("get_or_create_cluster"),
_GET_OR_CREATE_CLUSTER_METHOD_NAME,
timeout_s=RPC_REMOTE_METHOD_TIMEOUT_S,
user_id=user_id,
wallet_id=wallet_id,
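
The hunk above replaces the inline `RPCMethodName("get_or_create_cluster")` with a module-level constant validated once at import time via pydantic's `TypeAdapter`, so a malformed method name fails at service startup rather than inside the RPC call path. A minimal, self-contained sketch of that pattern, using a hypothetical `MethodName` constrained-string type in place of the repo's `RPCMethodName`:

```python
from typing import Annotated, Final

from pydantic import StringConstraints, TypeAdapter, ValidationError

# hypothetical stand-in for RPCMethodName: lowercase identifiers with underscores
MethodName = Annotated[str, StringConstraints(pattern=r"^[a-z][a-z0-9_]*$")]

# validated exactly once, at import time
GET_OR_CREATE_CLUSTER: Final[str] = TypeAdapter(MethodName).validate_python(
    "get_or_create_cluster"
)

try:
    TypeAdapter(MethodName).validate_python("not a valid method name!")
except ValidationError as exc:
    print(f"rejected at import time: {exc.error_count()} error(s)")
```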
@@ -589,6 +589,12 @@ async def _process_executing_tasks(
) -> None:
"""process executing tasks from the 3rd party backend"""

@abstractmethod
async def _release_resources(
self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
) -> None:
"""release resources used by the scheduler for a given user and project"""

async def apply(
self,
*,
@@ -654,6 +660,7 @@ async def apply(

# 7. Are we done scheduling that pipeline?
if not dag.nodes() or pipeline_result in COMPLETED_STATES:
await self._release_resources(user_id, project_id, comp_run)
# there is nothing left, the run is completed, we're done here
_logger.info(
"pipeline %s scheduling completed with result %s",
@@ -4,7 +4,7 @@
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any
from typing import Any, Final

import arrow
from dask_task_models_library.container_tasks.errors import TaskCancelledError
@@ -23,7 +23,7 @@
from models_library.users import UserID
from pydantic import PositiveInt
from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
from servicelib.logging_utils import log_catch
from servicelib.logging_utils import log_catch, log_context

from ...core.errors import (
ComputationalBackendNotConnectedError,
@@ -56,13 +56,16 @@

_logger = logging.getLogger(__name__)

_DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{run_id}"


@asynccontextmanager
async def _cluster_dask_client(
user_id: UserID,
scheduler: "DaskScheduler",
*,
use_on_demand_clusters: bool,
run_id: PositiveInt,
run_metadata: RunMetadataDict,
) -> AsyncIterator[DaskClient]:
cluster: BaseCluster = scheduler.settings.default_cluster
@@ -72,7 +75,9 @@ async def _cluster_dask_client(
user_id=user_id,
wallet_id=run_metadata.get("wallet_id"),
)
async with scheduler.dask_clients_pool.acquire(cluster) as client:
async with scheduler.dask_clients_pool.acquire(
cluster, ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=run_id)
) as client:
yield client


@@ -101,6 +106,7 @@ async def _start_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
# Change the tasks state to PENDING
@@ -151,6 +157,7 @@ async def _get_tasks_status(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
return await client.get_tasks_status([f"{t.job_id}" for t in tasks])
@@ -171,6 +178,7 @@ async def _process_executing_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
task_progresses = await client.get_tasks_progress(
@@ -217,6 +225,22 @@ async def _process_executing_tasks(
)
)

async def _release_resources(
self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
) -> None:
"""release resources used by the scheduler for a given user and project"""
with (
log_catch(_logger, reraise=False),
log_context(
_logger,
logging.INFO,
msg=f"releasing resources for {user_id=}, {project_id=}, {comp_run.run_id=}",
),
):
await self.dask_clients_pool.release_client_ref(
ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=comp_run.run_id)
)

async def _stop_tasks(
self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB
) -> None:
@@ -226,6 +250,7 @@ async def _stop_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
await asyncio.gather(
@@ -259,6 +284,7 @@ async def _process_completed_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
tasks_results = await asyncio.gather(
@@ -278,6 +304,7 @@ async def _process_completed_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
await asyncio.gather(
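
Across this module the dask client's lifetime is now tied to a run: every `dask_clients_pool.acquire(...)` call passes the same `{user_id}:{run_id}` ref, and `_release_resources` drops that ref via `release_client_ref(...)` once the pipeline completes, so the client survives across scheduler iterations but not beyond the run. The `acquire(cluster, ref=...)` / `release_client_ref(ref=...)` names come from the diff; the sketch below is only a guess at how such reference counting could look, not the actual `DaskClientsPool` implementation:

```python
import contextlib
from collections import defaultdict
from collections.abc import AsyncIterator


class _RefCountedPool:
    """toy pool: a client stays alive as long as at least one ref still holds it"""

    def __init__(self) -> None:
        self._clients: dict[str, object] = {}
        self._refs: dict[str, set[str]] = defaultdict(set)

    @contextlib.asynccontextmanager
    async def acquire(self, cluster: str, *, ref: str) -> AsyncIterator[object]:
        # stand-in for creating/reusing a DaskClient for this cluster
        client = self._clients.setdefault(cluster, object())
        self._refs[cluster].add(ref)
        yield client  # leaving the context does NOT drop the ref

    async def release_client_ref(self, *, ref: str) -> None:
        for cluster, refs in list(self._refs.items()):
            refs.discard(ref)
            if not refs:
                # last ref gone: the real pool would disconnect the dask client here
                self._clients.pop(cluster, None)
                self._refs.pop(cluster, None)
```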
@@ -120,43 +120,42 @@ async def create(
tasks_file_link_type: FileLinkType,
cluster_type: ClusterTypeInModel,
) -> "DaskClient":
_logger.info(
"Initiating connection to %s with auth: %s, type: %s",
f"dask-scheduler at {endpoint}",
authentication,
cluster_type,
)
async for attempt in AsyncRetrying(
reraise=True,
before_sleep=before_sleep_log(_logger, logging.INFO),
wait=wait_fixed(0.3),
stop=stop_after_attempt(3),
with log_context(
_logger,
logging.INFO,
msg=f"create dask client to dask-scheduler at {endpoint=} with {authentication=}, {cluster_type=}",
):
with attempt:
_logger.debug(
"Connecting to %s, attempt %s...",
endpoint,
attempt.retry_state.attempt_number,
)
backend = await connect_to_dask_scheduler(endpoint, authentication)
dask_utils.check_scheduler_status(backend.client)
instance = cls(
app=app,
backend=backend,
settings=settings,
tasks_file_link_type=tasks_file_link_type,
cluster_type=cluster_type,
)
_logger.info(
"Connection to %s succeeded [%s]",
f"dask-scheduler at {endpoint}",
json_dumps(attempt.retry_state.retry_object.statistics),
)
_logger.info(
"Scheduler info:\n%s",
json_dumps(backend.client.scheduler_info(), indent=2),
)
return instance
async for attempt in AsyncRetrying(
reraise=True,
before_sleep=before_sleep_log(_logger, logging.INFO),
wait=wait_fixed(0.3),
stop=stop_after_attempt(3),
):
with attempt:
_logger.debug(
"Connecting to %s, attempt %s...",
endpoint,
attempt.retry_state.attempt_number,
)
backend = await connect_to_dask_scheduler(endpoint, authentication)
dask_utils.check_scheduler_status(backend.client)
instance = cls(
app=app,
backend=backend,
settings=settings,
tasks_file_link_type=tasks_file_link_type,
cluster_type=cluster_type,
)
_logger.info(
"Connection to %s succeeded [%s]",
f"dask-scheduler at {endpoint}",
json_dumps(attempt.retry_state.retry_object.statistics),
)
_logger.info(
"Scheduler info:\n%s",
json_dumps(backend.client.scheduler_info(), indent=2),
)
return instance
# this is to satisfy pylance
err_msg = "Could not create client"
raise ValueError(err_msg)
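
The rewritten `create` wraps the whole retry loop in a single `log_context` block instead of separate "Initiating…"/"succeeded" log calls, so entry and exit are always logged together even when connecting fails. A hedged sketch of the same shape, with a simplified stand-in for `servicelib.logging_utils.log_context` (the real helper likely differs) and a dummy connection step:

```python
import contextlib
import logging

from tenacity import AsyncRetrying, before_sleep_log, stop_after_attempt, wait_fixed

_logger = logging.getLogger(__name__)


@contextlib.contextmanager
def log_context(logger: logging.Logger, level: int, msg: str):
    # simplified stand-in for servicelib.logging_utils.log_context
    logger.log(level, "starting %s", msg)
    try:
        yield
    finally:
        logger.log(level, "done %s", msg)


async def connect(endpoint: str) -> str:
    with log_context(_logger, logging.INFO, msg=f"create dask client to {endpoint=}"):
        async for attempt in AsyncRetrying(
            reraise=True,
            before_sleep=before_sleep_log(_logger, logging.INFO),
            wait=wait_fixed(0.3),
            stop=stop_after_attempt(3),
        ):
            with attempt:
                return f"client-for-{endpoint}"  # stand-in for the real connection
    raise ValueError("could not create client")  # unreachable; satisfies type checkers
```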