Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

_logger = logging.getLogger(__name__)

_DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{run_id}"
_DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{project_id}:{run_id}"


@asynccontextmanager
Expand All @@ -65,6 +65,7 @@ async def _cluster_dask_client(
scheduler: "DaskScheduler",
*,
use_on_demand_clusters: bool,
project_id: ProjectID,
run_id: PositiveInt,
run_metadata: RunMetadataDict,
) -> AsyncIterator[DaskClient]:
Expand All @@ -76,7 +77,10 @@ async def _cluster_dask_client(
wallet_id=run_metadata.get("wallet_id"),
)
async with scheduler.dask_clients_pool.acquire(
cluster, ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=run_id)
cluster,
ref=_DASK_CLIENT_RUN_REF.format(
user_id=user_id, project_id=project_id, run_id=run_id
),
) as client:
yield client

Expand Down Expand Up @@ -106,6 +110,7 @@ async def _start_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand Down Expand Up @@ -157,6 +162,7 @@ async def _get_tasks_status(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand All @@ -178,6 +184,7 @@ async def _process_executing_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand Down Expand Up @@ -238,7 +245,9 @@ async def _release_resources(
),
):
await self.dask_clients_pool.release_client_ref(
ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=comp_run.run_id)
ref=_DASK_CLIENT_RUN_REF.format(
user_id=user_id, project_id=project_id, run_id=comp_run.run_id
)
)

async def _stop_tasks(
Expand All @@ -250,6 +259,7 @@ async def _stop_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand Down Expand Up @@ -284,6 +294,7 @@ async def _process_completed_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand All @@ -304,6 +315,7 @@ async def _process_completed_tasks(
user_id,
self,
use_on_demand_clusters=comp_run.use_on_demand_clusters,
project_id=comp_run.project_uuid,
run_id=comp_run.run_id,
run_metadata=comp_run.metadata,
) as client:
Expand Down
Loading