88from fastapi import FastAPI
99from models_library .clusters import BaseCluster , ClusterTypeInModel
1010from pydantic import AnyUrl
11+ from servicelib .logging_utils import log_context
1112
1213from ..core .errors import (
1314 ComputationalBackendNotConnectedError ,
1920from ..utils .dask_client_utils import TaskHandlers
2021from .dask_client import DaskClient
2122
22- logger = logging .getLogger (__name__ )
23+ _logger = logging .getLogger (__name__ )
2324
2425
2526_ClusterUrl : TypeAlias = AnyUrl
@@ -62,48 +63,51 @@ async def delete(self) -> None:
6263
6364 @asynccontextmanager
6465 async def acquire (self , cluster : BaseCluster ) -> AsyncIterator [DaskClient ]:
66+ """returns a dask client for the given cluster
67+ This method is thread-safe and can be called concurrently.
68+ If the cluster is not found in the pool, it will create a new dask client for it.
69+
70+ """
71+
6572 async def _concurently_safe_acquire_client () -> DaskClient :
6673 async with self ._client_acquisition_lock :
67- dask_client = self ._cluster_to_client_map .get (cluster .endpoint )
68-
69- # we create a new client if that cluster was never used before
70- logger .debug (
71- "acquiring connection to cluster %s:%s" ,
72- cluster .endpoint ,
73- cluster .name ,
74- )
75- if not dask_client :
76- tasks_file_link_type = (
77- self .settings .COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE
78- )
79- if cluster == self .settings .default_cluster :
74+ with log_context (
75+ _logger ,
76+ logging .DEBUG ,
77+ f"acquire dask client for { cluster .name = } :{ cluster .endpoint } " ,
78+ ):
79+ dask_client = self ._cluster_to_client_map .get (cluster .endpoint )
80+ if not dask_client :
8081 tasks_file_link_type = (
81- self .settings .COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE
82+ self .settings .COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE
8283 )
83- if cluster .type == ClusterTypeInModel .ON_DEMAND .value :
84- tasks_file_link_type = (
85- self .settings .COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE
84+ if cluster == self .settings .default_cluster :
85+ tasks_file_link_type = (
86+ self .settings .COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE
87+ )
88+ if cluster .type == ClusterTypeInModel .ON_DEMAND .value :
89+ tasks_file_link_type = (
90+ self .settings .COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE
91+ )
92+ self ._cluster_to_client_map [cluster .endpoint ] = dask_client = (
93+ await DaskClient .create (
94+ app = self .app ,
95+ settings = self .settings ,
96+ endpoint = cluster .endpoint ,
97+ authentication = cluster .authentication ,
98+ tasks_file_link_type = tasks_file_link_type ,
99+ cluster_type = cluster .type ,
100+ )
86101 )
87- self ._cluster_to_client_map [
88- cluster .endpoint
89- ] = dask_client = await DaskClient .create (
90- app = self .app ,
91- settings = self .settings ,
92- endpoint = cluster .endpoint ,
93- authentication = cluster .authentication ,
94- tasks_file_link_type = tasks_file_link_type ,
95- cluster_type = cluster .type ,
96- )
97- if self ._task_handlers :
98- dask_client .register_handlers (self ._task_handlers )
99-
100- logger .debug ("created new client to cluster %s" , f"{ cluster = } " )
101- logger .debug (
102- "list of clients: %s" , f"{ self ._cluster_to_client_map = } "
103- )
104-
105- assert dask_client # nosec
106- return dask_client
102+ if self ._task_handlers :
103+ dask_client .register_handlers (self ._task_handlers )
104+
105+ _logger .debug (
106+ "list of clients: %s" , f"{ self ._cluster_to_client_map = } "
107+ )
108+
109+ assert dask_client # nosec
110+ return dask_client
107111
108112 try :
109113 dask_client = await _concurently_safe_acquire_client ()
@@ -129,7 +133,7 @@ async def on_startup() -> None:
129133 app = app , settings = settings
130134 )
131135
132- logger .info (
136+ _logger .info (
133137 "Default cluster is set to %s" ,
134138 f"{ settings .default_cluster !r} " ,
135139 )
0 commit comments