     TaskProgressEvent,
 )
 from dask_task_models_library.container_tasks.io import TaskOutputData
-from dask_task_models_library.container_tasks.utils import parse_dask_job_id
 from models_library.clusters import BaseCluster
 from models_library.errors import ErrorDict
 from models_library.projects import ProjectID
⋮
 from pydantic import PositiveInt
 from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
 from servicelib.logging_utils import log_catch, log_context
-from servicelib.redis._client import RedisClientSDK
 from servicelib.redis._semaphore_decorator import (
     with_limited_concurrency_cm,
 )
⋮
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"


-def _get_redis_client_from_scheduler(
-    _user_id: UserID,
-    scheduler: "DaskScheduler",
-    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
-) -> RedisClientSDK:
-    return scheduler.redis_client
-
-
-def _get_semaphore_cluster_redis_key(
-    user_id: UserID,
-    *args,  # pylint: disable=unused-argument # noqa: ARG001
-    run_metadata: RunMetadataDict,
-    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
-) -> str:
-    return f"{APP_NAME}-cluster-user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}"
-
-
-def _get_semaphore_capacity_from_scheduler(
-    _user_id: UserID,
-    scheduler: "DaskScheduler",
-    **kwargs,  # pylint: disable=unused-argument # noqa: ARG001
-) -> int:
-    return (
-        scheduler.settings.COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS
-    )
-
-
-@with_limited_concurrency_cm(
-    _get_redis_client_from_scheduler,
-    key=_get_semaphore_cluster_redis_key,
-    capacity=_get_semaphore_capacity_from_scheduler,
-    blocking=True,
-    blocking_timeout=None,
-)
 @asynccontextmanager
 async def _cluster_dask_client(
     user_id: UserID,
@@ -123,12 +87,27 @@ async def _cluster_dask_client(
             user_id=user_id,
             wallet_id=run_metadata.get("wallet_id"),
         )
-    async with scheduler.dask_clients_pool.acquire(
-        cluster,
-        ref=_DASK_CLIENT_RUN_REF.format(
-            user_id=user_id, project_id=project_id, run_id=run_id
-        ),
-    ) as client:
+
+    @with_limited_concurrency_cm(
+        scheduler.redis_client,
+        key=f"{APP_NAME}-cluster-user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}",
+        capacity=scheduler.settings.COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS,
+        blocking=True,
+        blocking_timeout=None,
+    )
+    @asynccontextmanager
+    async def _acquire_client(
+        user_id: UserID, scheduler: "DaskScheduler"
+    ) -> AsyncIterator[DaskClient]:
+        async with scheduler.dask_clients_pool.acquire(
+            cluster,
+            ref=_DASK_CLIENT_RUN_REF.format(
+                user_id=user_id, project_id=project_id, run_id=run_id
+            ),
+        ) as client:
+            yield client
+
+    async with _acquire_client(user_id, scheduler) as client:
         yield client


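# --- Illustration (not part of the diff) --------------------------------------
# The hunks above drop the module-level provider callables and instead apply
# `with_limited_concurrency_cm` inside `_cluster_dask_client`, where `scheduler`,
# `user_id` and `run_metadata` are already in scope, so the Redis client,
# semaphore key and capacity can be passed as plain values. Below is a
# dependency-free, process-local analogue of the behaviour the decorator
# enforces (the real one coordinates the limit across replicas through Redis).
# All names and the capacity of 2 are made up for this sketch.
import asyncio
from contextlib import asynccontextmanager

_CAPACITY = 2
_slots = asyncio.Semaphore(_CAPACITY)  # stand-in for the distributed semaphore


@asynccontextmanager
async def _limited_cluster_client(name: str):
    # Wait for a free slot, mirroring blocking=True / blocking_timeout=None.
    async with _slots:
        yield name  # at most _CAPACITY callers hold a "client" at once


async def _use(name: str) -> None:
    async with _limited_cluster_client(name) as client:
        await asyncio.sleep(0.1)  # pretend to talk to the Dask cluster
        print(f"{client} done")


async def _demo() -> None:
    # Only _CAPACITY of the 5 tasks hold a client at any given moment.
    await asyncio.gather(*(_use(f"task-{i}") for i in range(5)))


if __name__ == "__main__":
    asyncio.run(_demo())
# -------------------------------------------------------------------------------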
@@ -370,17 +349,16 @@ async def _process_completed_tasks(
                 self._process_task_result(
                     task,
                     result,
-                    comp_run.metadata,
                     iteration,
-                    comp_run.run_id,
+                    comp_run,
                 )
                 for task, result in zip(tasks, tasks_results, strict=True)
             ),
             limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
         ):
             with log_catch(_logger, reraise=False):
                 task_can_be_cleaned, job_id = await future
-                if task_can_be_cleaned:
+                if task_can_be_cleaned and job_id:
                     await client.release_task_result(job_id)

     async def _handle_successful_run(
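# --- Illustration (not part of the diff) --------------------------------------
# `_process_task_result` now returns `tuple[bool, str | None]` (see the last
# hunk below): the `assert task.job_id` guards are gone, so the job id may be
# None and the caller above releases the Dask result only when a task completed
# AND a job id is present. Hypothetical stand-ins for this sketch:
import asyncio


async def _release_task_result(job_id: str) -> None:
    print(f"released {job_id}")


async def _demo_release(results: list[tuple[bool, str | None]]) -> None:
    for task_can_be_cleaned, job_id in results:
        # Release only when both conditions hold, as in the hunk above.
        if task_can_be_cleaned and job_id:
            await _release_task_result(job_id)


if __name__ == "__main__":
    asyncio.run(_demo_release([(True, "job-1"), (True, None), (False, "job-2")]))
# -------------------------------------------------------------------------------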
@@ -411,11 +389,9 @@ async def _handle_successful_run(
     async def _handle_computational_retrieval_error(
         self,
         task: CompTaskAtDB,
-        user_id: UserID,
         result: ComputationalBackendTaskResultsNotReadyError,
         log_error_context: dict[str, Any],
     ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]:
-        assert task.job_id  # nosec
         _logger.warning(
             **create_troubleshooting_log_kwargs(
                 f"Retrieval of task {task.job_id} result timed-out",
@@ -448,10 +424,7 @@ async def _handle_computational_retrieval_error(
                     type=_TASK_RETRIEVAL_ERROR_TYPE,
                     ctx={
                         _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: f"{check_time}",
-                        "user_id": user_id,
-                        "project_id": f"{task.project_id}",
-                        "node_id": f"{task.node_id}",
-                        "job_id": task.job_id,
+                        **log_error_context,
                     },
                 )
             )
@@ -472,7 +445,6 @@ async def _handle_computational_backend_not_connected_error(
         result: ComputationalBackendNotConnectedError,
         log_error_context: dict[str, Any],
     ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]:
-        assert task.job_id  # nosec
         _logger.warning(
             **create_troubleshooting_log_kwargs(
                 f"Computational backend disconnected when retrieving task {task.job_id} result",
@@ -492,8 +464,6 @@ async def _handle_task_error(
         result: BaseException,
         log_error_context: dict[str, Any],
     ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]:
-        assert task.job_id  # nosec
-
         # the task itself failed, check why
         if isinstance(result, TaskCancelledError):
             _logger.info(
@@ -529,102 +499,100 @@ async def _process_task_result(
         self,
         task: TaskStateTracker,
         result: BaseException | TaskOutputData,
-        run_metadata: RunMetadataDict,
         iteration: Iteration,
-        run_id: PositiveInt,
-    ) -> tuple[bool, str]:
+        comp_run: CompRunsAtDB,
+    ) -> tuple[bool, str | None]:
         """Returns True and the job ID if the task was successfully processed and can be released from the Dask cluster."""
-        _logger.debug("received %s result: %s", f"{task=}", f"{result=}")
-        assert task.current.job_id  # nosec
-        (
-            _service_key,
-            _service_version,
-            user_id,
-            project_id,
-            node_id,
-        ) = parse_dask_job_id(task.current.job_id)
-
-        assert task.current.project_id == project_id  # nosec
-        assert task.current.node_id == node_id  # nosec
-        log_error_context = {
-            "user_id": user_id,
-            "project_id": project_id,
-            "node_id": node_id,
-            "job_id": task.current.job_id,
-        }
-
-        if isinstance(result, TaskOutputData):
-            (
-                task_final_state,
-                simcore_platform_status,
-                task_errors,
-                task_completed,
-            ) = await self._handle_successful_run(
-                task.current, result, log_error_context
-            )
+        with log_context(
+            _logger, logging.DEBUG, msg=f"{comp_run.run_id=}, {task=}, {result=}"
+        ):
+            log_error_context = {
+                "user_id": comp_run.user_id,
+                "project_id": f"{comp_run.project_uuid}",
+                "node_id": f"{task.current.node_id}",
+                "job_id": task.current.job_id,
+            }
+
+            if isinstance(result, TaskOutputData):
+                (
+                    task_final_state,
+                    simcore_platform_status,
+                    task_errors,
+                    task_completed,
+                ) = await self._handle_successful_run(
+                    task.current, result, log_error_context
+                )

-        elif isinstance(result, ComputationalBackendTaskResultsNotReadyError):
-            (
-                task_final_state,
-                simcore_platform_status,
-                task_errors,
-                task_completed,
-            ) = await self._handle_computational_retrieval_error(
-                task.current, user_id, result, log_error_context
-            )
-        elif isinstance(result, ComputationalBackendNotConnectedError):
-            (
-                task_final_state,
-                simcore_platform_status,
-                task_errors,
-                task_completed,
-            ) = await self._handle_computational_backend_not_connected_error(
-                task.current, result, log_error_context
-            )
-        else:
-            (
-                task_final_state,
-                simcore_platform_status,
-                task_errors,
-                task_completed,
-            ) = await self._handle_task_error(task.current, result, log_error_context)
-
-        # we need to remove any invalid files in the storage
-        await clean_task_output_and_log_files_if_invalid(
-            self.db_engine, user_id, project_id, node_id
-        )
+            elif isinstance(result, ComputationalBackendTaskResultsNotReadyError):
+                (
+                    task_final_state,
+                    simcore_platform_status,
+                    task_errors,
+                    task_completed,
+                ) = await self._handle_computational_retrieval_error(
+                    task.current, result, log_error_context
+                )
+            elif isinstance(result, ComputationalBackendNotConnectedError):
+                (
+                    task_final_state,
+                    simcore_platform_status,
+                    task_errors,
+                    task_completed,
+                ) = await self._handle_computational_backend_not_connected_error(
+                    task.current, result, log_error_context
+                )
+            else:
+                (
+                    task_final_state,
+                    simcore_platform_status,
+                    task_errors,
+                    task_completed,
+                ) = await self._handle_task_error(
+                    task.current, result, log_error_context
+                )

-        if task_completed:
-            # resource tracking
-            await publish_service_resource_tracking_stopped(
-                self.rabbitmq_client,
-                ServiceRunID.get_resource_tracking_run_id_for_computational(
-                    user_id, project_id, node_id, iteration
-                ),
-                simcore_platform_status=simcore_platform_status,
-            )
-            # instrumentation
-            await publish_service_stopped_metrics(
-                self.rabbitmq_client,
-                user_id=user_id,
-                simcore_user_agent=run_metadata.get(
-                    "simcore_user_agent", UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
-                ),
-                task=task.current,
-                task_final_state=task_final_state,
-            )
+            # we need to remove any invalid files in the storage
+            await clean_task_output_and_log_files_if_invalid(
+                self.db_engine,
+                comp_run.user_id,
+                comp_run.project_uuid,
+                task.current.node_id,
+            )

-        await CompTasksRepository(self.db_engine).update_project_tasks_state(
-            task.current.project_id,
-            run_id,
-            [task.current.node_id],
-            task_final_state if task_completed else task.previous.state,
-            errors=task_errors,
-            optional_progress=1 if task_completed else None,
-            optional_stopped=arrow.utcnow().datetime if task_completed else None,
-        )
+            if task_completed:
+                # resource tracking
+                await publish_service_resource_tracking_stopped(
+                    self.rabbitmq_client,
+                    ServiceRunID.get_resource_tracking_run_id_for_computational(
+                        comp_run.user_id,
+                        comp_run.project_uuid,
+                        task.current.node_id,
+                        iteration,
+                    ),
+                    simcore_platform_status=simcore_platform_status,
+                )
+                # instrumentation
+                await publish_service_stopped_metrics(
+                    self.rabbitmq_client,
+                    user_id=comp_run.user_id,
+                    simcore_user_agent=comp_run.metadata.get(
+                        "simcore_user_agent", UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
+                    ),
+                    task=task.current,
+                    task_final_state=task_final_state,
+                )
+
+            await CompTasksRepository(self.db_engine).update_project_tasks_state(
+                task.current.project_id,
+                comp_run.run_id,
+                [task.current.node_id],
+                task_final_state if task_completed else task.previous.state,
+                errors=task_errors,
+                optional_progress=1 if task_completed else None,
+                optional_stopped=arrow.utcnow().datetime if task_completed else None,
+            )

-        return task_completed, task.current.job_id
+            return task_completed, task.current.job_id

     async def _task_progress_change_handler(
         self, event: tuple[UnixTimestamp, Any]